Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
eac1753d
Commit
eac1753d
authored
Sep 15, 2021
by
Qianfeng Zhang
Browse files
Avoid converting to compType from dstDataType before writing the output value
parent
5a9f6308
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
27 deletions
+63
-27
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp
...sor_operation/gridwise_generic_2d_reduction_blockwise.hpp
+21
-9
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp
...ation/gridwise_generic_2d_reduction_direct_threadwise.hpp
+21
-9
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp
...eration/gridwise_generic_2d_reduction_direct_warpwise.hpp
+21
-9
No files found.
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp
View file @
eac1753d
...
@@ -180,6 +180,10 @@ struct GridwiseReduction_xy_to_x_blockwise
...
@@ -180,6 +180,10 @@ struct GridwiseReduction_xy_to_x_blockwise
if
(
!
float_equal_one
{}(
alpha
))
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
if
(
!
float_equal_zero
{}(
beta
))
{
{
auto
threadwise_dst_load
=
auto
threadwise_dst_load
=
...
@@ -200,11 +204,11 @@ struct GridwiseReduction_xy_to_x_blockwise
...
@@ -200,11 +204,11 @@ struct GridwiseReduction_xy_to_x_blockwise
threadwise_dst_load
.
Run
(
threadwise_dst_load
.
Run
(
dst1dDesc
,
dst_global_buf
,
ReducedDataDesc
,
make_tuple
(
I0
),
priorDstValue_buf
);
dst1dDesc
,
dst_global_buf
,
ReducedDataDesc
,
make_tuple
(
I0
),
priorDstValue_buf
);
accu
Value_buf
(
I0
)
+=
type_convert
<
compType
>
{}(
priorDstValue_buf
[
I0
]
*
beta
)
;
dst
Value_buf
(
I0
)
+=
priorDstValue_buf
[
I0
]
*
beta
;
}
}
auto
threadwise_dst_store
=
auto
threadwise_dst_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
comp
Type
,
ThreadwiseTensorSliceTransfer_v1r3
<
dstData
Type
,
dstDataType
,
dstDataType
,
decltype
(
ReducedDataDesc
),
decltype
(
ReducedDataDesc
),
dst1dDescType
,
dst1dDescType
,
...
@@ -218,7 +222,7 @@ struct GridwiseReduction_xy_to_x_blockwise
...
@@ -218,7 +222,7 @@ struct GridwiseReduction_xy_to_x_blockwise
make_multi_index
(
block_global_1d_id
));
make_multi_index
(
block_global_1d_id
));
threadwise_dst_store
.
Run
(
threadwise_dst_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accu
Value_buf
,
dst1dDesc
,
dst_global_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
dst
Value_buf
,
dst1dDesc
,
dst_global_buf
);
}
}
};
};
...
@@ -345,6 +349,10 @@ struct GridwiseReduction_xy_to_x_blockwise
...
@@ -345,6 +349,10 @@ struct GridwiseReduction_xy_to_x_blockwise
if
(
!
float_equal_one
{}(
alpha
))
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
if
(
!
float_equal_zero
{}(
beta
))
{
{
auto
threadwise_dst_load
=
auto
threadwise_dst_load
=
...
@@ -368,11 +376,11 @@ struct GridwiseReduction_xy_to_x_blockwise
...
@@ -368,11 +376,11 @@ struct GridwiseReduction_xy_to_x_blockwise
make_tuple
(
I0
),
make_tuple
(
I0
),
priorDstValue_buf
);
priorDstValue_buf
);
accu
Value_buf
(
I0
)
+=
type_convert
<
compType
>
{}(
priorDstValue_buf
[
I0
]
*
beta
)
;
dst
Value_buf
(
I0
)
+=
priorDstValue_buf
[
I0
]
*
beta
;
}
}
auto
threadwise_dst_val_store
=
auto
threadwise_dst_val_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
comp
Type
,
ThreadwiseTensorSliceTransfer_v1r3
<
dstData
Type
,
dstDataType
,
dstDataType
,
decltype
(
ReducedDataDesc
),
decltype
(
ReducedDataDesc
),
dst1dDescType
,
dst1dDescType
,
...
@@ -400,7 +408,7 @@ struct GridwiseReduction_xy_to_x_blockwise
...
@@ -400,7 +408,7 @@ struct GridwiseReduction_xy_to_x_blockwise
make_multi_index
(
block_global_1d_id
));
make_multi_index
(
block_global_1d_id
));
threadwise_dst_val_store
.
Run
(
threadwise_dst_val_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accu
Value_buf
,
dst1dDesc
,
dst_global_val_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
dst
Value_buf
,
dst1dDesc
,
dst_global_val_buf
);
threadwise_dst_idx_store
.
Run
(
threadwise_dst_idx_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accuIndex_buf
,
dst1dDesc
,
dst_global_idx_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
accuIndex_buf
,
dst1dDesc
,
dst_global_idx_buf
);
}
}
...
@@ -547,6 +555,10 @@ struct GridwiseReduction_xy_to_x_blockwise
...
@@ -547,6 +555,10 @@ struct GridwiseReduction_xy_to_x_blockwise
if
(
!
float_equal_one
{}(
alpha
))
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
if
(
!
float_equal_zero
{}(
beta
))
{
{
auto
threadwise_dst_load
=
auto
threadwise_dst_load
=
...
@@ -570,11 +582,11 @@ struct GridwiseReduction_xy_to_x_blockwise
...
@@ -570,11 +582,11 @@ struct GridwiseReduction_xy_to_x_blockwise
make_tuple
(
I0
),
make_tuple
(
I0
),
priorDstValue_buf
);
priorDstValue_buf
);
accu
Value_buf
(
I0
)
+=
type_convert
<
compType
>
{}(
priorDstValue_buf
[
I0
]
*
beta
)
;
dst
Value_buf
(
I0
)
+=
priorDstValue_buf
[
I0
]
*
beta
;
}
}
auto
threadwise_dst_val_store
=
auto
threadwise_dst_val_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
comp
Type
,
ThreadwiseTensorSliceTransfer_v1r3
<
dstData
Type
,
dstDataType
,
dstDataType
,
decltype
(
ReducedDataDesc
),
decltype
(
ReducedDataDesc
),
dst1dDescType
,
dst1dDescType
,
...
@@ -602,7 +614,7 @@ struct GridwiseReduction_xy_to_x_blockwise
...
@@ -602,7 +614,7 @@ struct GridwiseReduction_xy_to_x_blockwise
make_multi_index
(
block_global_1d_id
));
make_multi_index
(
block_global_1d_id
));
threadwise_dst_val_store
.
Run
(
threadwise_dst_val_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accu
Value_buf
,
dst1dDesc
,
dst_global_val_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
dst
Value_buf
,
dst1dDesc
,
dst_global_val_buf
);
threadwise_dst_idx_store
.
Run
(
threadwise_dst_idx_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accuIndex_buf
,
dst1dDesc
,
dst_global_idx_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
accuIndex_buf
,
dst1dDesc
,
dst_global_idx_buf
);
}
}
...
...
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp
View file @
eac1753d
...
@@ -147,6 +147,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
...
@@ -147,6 +147,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
if
(
!
float_equal_one
{}(
alpha
))
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
if
(
!
float_equal_zero
{}(
beta
))
{
{
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
dstDataType
,
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
dstDataType
,
...
@@ -166,11 +170,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
...
@@ -166,11 +170,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
threadwise_dst_load
.
Run
(
threadwise_dst_load
.
Run
(
dst1dDesc
,
dst_global_buf
,
ReducedDataDesc
,
make_tuple
(
I0
),
priorDstValue_buf
);
dst1dDesc
,
dst_global_buf
,
ReducedDataDesc
,
make_tuple
(
I0
),
priorDstValue_buf
);
accu
Value_buf
(
I0
)
+=
type_convert
<
compType
>
{}(
priorDstValue_buf
[
I0
]
*
beta
)
;
dst
Value_buf
(
I0
)
+=
priorDstValue_buf
[
I0
]
*
beta
;
}
}
auto
threadwise_dst_store
=
auto
threadwise_dst_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
comp
Type
,
ThreadwiseTensorSliceTransfer_v1r3
<
dstData
Type
,
dstDataType
,
dstDataType
,
decltype
(
ReducedDataDesc
),
decltype
(
ReducedDataDesc
),
dst1dDescType
,
dst1dDescType
,
...
@@ -184,7 +188,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
...
@@ -184,7 +188,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_multi_index
(
thread_global_1d_id
));
make_multi_index
(
thread_global_1d_id
));
threadwise_dst_store
.
Run
(
threadwise_dst_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accu
Value_buf
,
dst1dDesc
,
dst_global_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
dst
Value_buf
,
dst1dDesc
,
dst_global_buf
);
};
};
template
<
>
template
<
>
...
@@ -271,6 +275,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
...
@@ -271,6 +275,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
if
(
!
float_equal_one
{}(
alpha
))
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
if
(
!
float_equal_zero
{}(
beta
))
{
{
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
dstDataType
,
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
dstDataType
,
...
@@ -290,11 +298,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
...
@@ -290,11 +298,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
threadwise_dst_load
.
Run
(
threadwise_dst_load
.
Run
(
dst1dDesc
,
dst_global_val_buf
,
ReducedDataDesc
,
make_tuple
(
I0
),
priorDstValue_buf
);
dst1dDesc
,
dst_global_val_buf
,
ReducedDataDesc
,
make_tuple
(
I0
),
priorDstValue_buf
);
accu
Value_buf
(
I0
)
+=
type_convert
<
compType
>
{}(
priorDstValue_buf
[
I0
]
*
beta
)
;
dst
Value_buf
(
I0
)
+=
priorDstValue_buf
[
I0
]
*
beta
;
}
}
auto
threadwise_dst_val_store
=
auto
threadwise_dst_val_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
comp
Type
,
ThreadwiseTensorSliceTransfer_v1r3
<
dstData
Type
,
dstDataType
,
dstDataType
,
decltype
(
ReducedDataDesc
),
decltype
(
ReducedDataDesc
),
dst1dDescType
,
dst1dDescType
,
...
@@ -322,7 +330,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
...
@@ -322,7 +330,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_multi_index
(
thread_global_1d_id
));
make_multi_index
(
thread_global_1d_id
));
threadwise_dst_val_store
.
Run
(
threadwise_dst_val_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accu
Value_buf
,
dst1dDesc
,
dst_global_val_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
dst
Value_buf
,
dst1dDesc
,
dst_global_val_buf
);
threadwise_dst_idx_store
.
Run
(
threadwise_dst_idx_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accuIndex_buf
,
dst1dDesc
,
dst_global_idx_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
accuIndex_buf
,
dst1dDesc
,
dst_global_idx_buf
);
};
};
...
@@ -430,6 +438,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
...
@@ -430,6 +438,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
if
(
!
float_equal_one
{}(
alpha
))
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
if
(
!
float_equal_zero
{}(
beta
))
{
{
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
dstDataType
,
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
dstDataType
,
...
@@ -449,11 +461,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
...
@@ -449,11 +461,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
threadwise_dst_load
.
Run
(
threadwise_dst_load
.
Run
(
dst1dDesc
,
dst_global_val_buf
,
ReducedDataDesc
,
make_tuple
(
I0
),
priorDstValue_buf
);
dst1dDesc
,
dst_global_val_buf
,
ReducedDataDesc
,
make_tuple
(
I0
),
priorDstValue_buf
);
accu
Value_buf
(
I0
)
+=
type_convert
<
compType
>
{}(
priorDstValue_buf
[
I0
]
*
beta
)
;
dst
Value_buf
(
I0
)
+=
priorDstValue_buf
[
I0
]
*
beta
;
}
}
auto
threadwise_dst_val_store
=
auto
threadwise_dst_val_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
comp
Type
,
ThreadwiseTensorSliceTransfer_v1r3
<
dstData
Type
,
dstDataType
,
dstDataType
,
decltype
(
ReducedDataDesc
),
decltype
(
ReducedDataDesc
),
dst1dDescType
,
dst1dDescType
,
...
@@ -481,7 +493,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
...
@@ -481,7 +493,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_multi_index
(
thread_global_1d_id
));
make_multi_index
(
thread_global_1d_id
));
threadwise_dst_val_store
.
Run
(
threadwise_dst_val_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accu
Value_buf
,
dst1dDesc
,
dst_global_val_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
dst
Value_buf
,
dst1dDesc
,
dst_global_val_buf
);
threadwise_dst_idx_store
.
Run
(
threadwise_dst_idx_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accuIndex_buf
,
dst1dDesc
,
dst_global_idx_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
accuIndex_buf
,
dst1dDesc
,
dst_global_idx_buf
);
};
};
...
...
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp
View file @
eac1753d
...
@@ -156,6 +156,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
...
@@ -156,6 +156,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if
(
!
float_equal_one
{}(
alpha
))
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
if
(
!
float_equal_zero
{}(
beta
))
{
{
auto
threadwise_dst_load
=
auto
threadwise_dst_load
=
...
@@ -176,11 +180,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
...
@@ -176,11 +180,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
threadwise_dst_load
.
Run
(
threadwise_dst_load
.
Run
(
dst1dDesc
,
dst_global_buf
,
ReducedDataDesc
,
make_tuple
(
I0
),
priorDstValue_buf
);
dst1dDesc
,
dst_global_buf
,
ReducedDataDesc
,
make_tuple
(
I0
),
priorDstValue_buf
);
accu
Value_buf
(
I0
)
+=
type_convert
<
compType
>
{}(
priorDstValue_buf
(
I0
)
*
beta
)
;
dst
Value_buf
(
I0
)
+=
priorDstValue_buf
(
I0
)
*
beta
;
}
}
auto
threadwise_dst_store
=
auto
threadwise_dst_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
comp
Type
,
ThreadwiseTensorSliceTransfer_v1r3
<
dstData
Type
,
dstDataType
,
dstDataType
,
decltype
(
ReducedDataDesc
),
decltype
(
ReducedDataDesc
),
dst1dDescType
,
dst1dDescType
,
...
@@ -194,7 +198,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
...
@@ -194,7 +198,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_multi_index
(
warp_global_1d_id
));
make_multi_index
(
warp_global_1d_id
));
threadwise_dst_store
.
Run
(
threadwise_dst_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accu
Value_buf
,
dst1dDesc
,
dst_global_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
dst
Value_buf
,
dst1dDesc
,
dst_global_buf
);
}
}
};
};
...
@@ -291,6 +295,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
...
@@ -291,6 +295,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if
(
!
float_equal_one
{}(
alpha
))
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
if
(
!
float_equal_zero
{}(
beta
))
{
{
auto
threadwise_dst_load
=
auto
threadwise_dst_load
=
...
@@ -314,11 +322,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
...
@@ -314,11 +322,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_tuple
(
I0
),
make_tuple
(
I0
),
priorDstValue_buf
);
priorDstValue_buf
);
accu
Value_buf
(
I0
)
+=
type_convert
<
compType
>
{}(
priorDstValue_buf
[
I0
]
*
beta
)
;
dst
Value_buf
(
I0
)
+=
priorDstValue_buf
[
I0
]
*
beta
;
}
}
auto
threadwise_dst_val_store
=
auto
threadwise_dst_val_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
comp
Type
,
ThreadwiseTensorSliceTransfer_v1r3
<
dstData
Type
,
dstDataType
,
dstDataType
,
decltype
(
ReducedDataDesc
),
decltype
(
ReducedDataDesc
),
dst1dDescType
,
dst1dDescType
,
...
@@ -346,7 +354,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
...
@@ -346,7 +354,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_multi_index
(
warp_global_1d_id
));
make_multi_index
(
warp_global_1d_id
));
threadwise_dst_val_store
.
Run
(
threadwise_dst_val_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accu
Value_buf
,
dst1dDesc
,
dst_global_val_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
dst
Value_buf
,
dst1dDesc
,
dst_global_val_buf
);
threadwise_dst_idx_store
.
Run
(
threadwise_dst_idx_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accuIndex_buf
,
dst1dDesc
,
dst_global_idx_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
accuIndex_buf
,
dst1dDesc
,
dst_global_idx_buf
);
}
}
...
@@ -466,6 +474,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
...
@@ -466,6 +474,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if
(
!
float_equal_one
{}(
alpha
))
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
if
(
!
float_equal_zero
{}(
beta
))
{
{
auto
threadwise_dst_load
=
auto
threadwise_dst_load
=
...
@@ -489,11 +501,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
...
@@ -489,11 +501,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_tuple
(
I0
),
make_tuple
(
I0
),
priorDstValue_buf
);
priorDstValue_buf
);
accu
Value_buf
(
I0
)
+=
type_convert
<
compType
>
{}(
priorDstValue_buf
[
I0
]
*
beta
)
;
dst
Value_buf
(
I0
)
+=
priorDstValue_buf
[
I0
]
*
beta
;
}
}
auto
threadwise_dst_val_store
=
auto
threadwise_dst_val_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
comp
Type
,
ThreadwiseTensorSliceTransfer_v1r3
<
dstData
Type
,
dstDataType
,
dstDataType
,
decltype
(
ReducedDataDesc
),
decltype
(
ReducedDataDesc
),
dst1dDescType
,
dst1dDescType
,
...
@@ -521,7 +533,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
...
@@ -521,7 +533,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_multi_index
(
warp_global_1d_id
));
make_multi_index
(
warp_global_1d_id
));
threadwise_dst_val_store
.
Run
(
threadwise_dst_val_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accu
Value_buf
,
dst1dDesc
,
dst_global_val_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
dst
Value_buf
,
dst1dDesc
,
dst_global_val_buf
);
threadwise_dst_idx_store
.
Run
(
threadwise_dst_idx_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accuIndex_buf
,
dst1dDesc
,
dst_global_idx_buf
);
ReducedDataDesc
,
make_tuple
(
I0
),
accuIndex_buf
,
dst1dDesc
,
dst_global_idx_buf
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment