Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
55159365
Commit
55159365
authored
Nov 16, 2021
by
Chao Liu
Browse files
refactor type_convert
parent
d8a632a8
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
116 additions
and
138 deletions
+116
-138
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp
...sor_operation/gridwise_generic_2d_reduction_blockwise.hpp
+10
-12
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp
...ation/gridwise_generic_2d_reduction_direct_threadwise.hpp
+10
-12
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp
...eration/gridwise_generic_2d_reduction_direct_warpwise.hpp
+10
-12
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp
...or_operation/gridwise_generic_2d_reduction_multiblock.hpp
+2
-2
composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp
...nclude/tensor_operation/reduction_functions_blockwise.hpp
+17
-17
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
...ude/tensor_operation/threadwise_tensor_slice_transfer.hpp
+3
-3
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
.../tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
+2
-2
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp
...ensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp
+2
-2
composable_kernel/include/utility/data_type.hpp
composable_kernel/include/utility/data_type.hpp
+29
-36
composable_kernel/include/utility/inner_product.hpp
composable_kernel/include/utility/inner_product.hpp
+4
-12
composable_kernel/include/utility/reduction_operator.hpp
composable_kernel/include/utility/reduction_operator.hpp
+4
-4
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+6
-6
host/host_tensor/include/host_tensor_generator.hpp
host/host_tensor/include/host_tensor_generator.hpp
+17
-18
No files found.
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp
View file @
55159365
...
...
@@ -95,7 +95,7 @@ struct GridwiseReduction_xy_to_x_blockwise
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}
(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
(
zeroVal
));
auto
dst_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_dst_global
,
dst1dDesc
.
GetElementSpaceSize
());
...
...
@@ -178,11 +178,11 @@ struct GridwiseReduction_xy_to_x_blockwise
if
(
thread_local_id
==
0
)
{
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}
(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}
(
accuValue_buf
[
I0
]);
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
{
...
...
@@ -246,7 +246,7 @@ struct GridwiseReduction_xy_to_x_blockwise
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}
(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
(
zeroVal
));
auto
dst_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_dst_global
,
dst1dDesc
.
GetElementSpaceSize
());
auto
dst_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
...
...
@@ -347,11 +347,11 @@ struct GridwiseReduction_xy_to_x_blockwise
if
(
thread_local_id
==
0
)
{
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}
(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}
(
accuValue_buf
[
I0
]);
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
{
...
...
@@ -433,10 +433,8 @@ struct GridwiseReduction_xy_to_x_blockwise
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
(
zeroVal
));
const
auto
src_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_indices_global
,
src2dDesc
.
GetElementSpaceSize
());
auto
dst_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
...
...
@@ -553,11 +551,11 @@ struct GridwiseReduction_xy_to_x_blockwise
if
(
thread_local_id
==
0
)
{
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}
(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}
(
accuValue_buf
[
I0
]);
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
{
...
...
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp
View file @
55159365
...
...
@@ -85,7 +85,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}
(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
(
zeroVal
));
auto
dst_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_dst_global
,
dst1dDesc
.
GetElementSpaceSize
());
...
...
@@ -145,11 +145,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{}));
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}
(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}
(
accuValue_buf
[
I0
]);
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
{
...
...
@@ -207,7 +207,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}
(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
(
zeroVal
));
auto
dst_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_dst_global
,
dst1dDesc
.
GetElementSpaceSize
());
auto
dst_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
...
...
@@ -273,11 +273,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{}));
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}
(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}
(
accuValue_buf
[
I0
]);
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
{
...
...
@@ -350,10 +350,8 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
(
zeroVal
));
const
auto
src_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_indices_global
,
src2dDesc
.
GetElementSpaceSize
());
auto
dst_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
...
...
@@ -436,11 +434,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{}));
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}
(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}
(
accuValue_buf
[
I0
]);
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
{
...
...
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp
View file @
55159365
...
...
@@ -85,7 +85,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}
(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
(
zeroVal
));
auto
dst_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_dst_global
,
dst1dDesc
.
GetElementSpaceSize
());
...
...
@@ -154,11 +154,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if
(
thread_inwarp_id
==
0
)
{
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}
(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}
(
accuValue_buf
[
I0
]);
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
{
...
...
@@ -218,7 +218,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}
(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
(
zeroVal
));
auto
dst_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_dst_global
,
dst1dDesc
.
GetElementSpaceSize
());
auto
dst_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
...
...
@@ -293,11 +293,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if
(
thread_inwarp_id
==
0
)
{
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}
(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}
(
accuValue_buf
[
I0
]);
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
{
...
...
@@ -375,10 +375,8 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
(
zeroVal
));
const
auto
src_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_indices_global
,
src2dDesc
.
GetElementSpaceSize
());
auto
dst_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
...
...
@@ -472,11 +470,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if
(
thread_inwarp_id
==
0
)
{
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}
(
alpha
);
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}
(
accuValue_buf
[
I0
]);
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
{
...
...
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp
View file @
55159365
...
...
@@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_multiblock
__shared__
compType
p_in_block_buffer
[
BlockBufferSize
];
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}
(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
(
zeroVal
));
auto
workspace_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
dst1dDesc
.
GetLength
(
I0
)
*
BlkGroupSize
);
...
...
@@ -223,7 +223,7 @@ struct GridwiseReduction_xy_to_x_multiblock
__shared__
int
p_in_block_indices_buffer
[
BlockBufferSize
];
const
auto
src_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}
(
zeroVal
));
p_src_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
(
zeroVal
));
auto
workspace_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
dst1dDesc
.
GetLength
(
I0
)
*
BlkGroupSize
);
auto
workspace_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
...
...
composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp
View file @
55159365
...
...
@@ -64,7 +64,7 @@ struct BlockwiseReduction_2d_block_buffer
offset
=
blockIsOneRow
?
buffer2dDesc
.
CalculateOffset
(
make_tuple
(
otherDimInd
,
thread_local_id
))
:
buffer2dDesc
.
CalculateOffset
(
make_tuple
(
thread_local_id
,
otherDimInd
));
compType
opData
=
type_convert
<
compType
>
{}
(
block_buffer
[
offset
]);
compType
opData
=
type_convert
<
compType
>
(
block_buffer
[
offset
]);
binop
::
calculate
(
lAccuData
,
opData
);
}
...
...
@@ -89,10 +89,10 @@ struct BlockwiseReduction_2d_block_buffer
?
buffer2dDesc
.
CalculateOffset
(
make_tuple
(
0
,
thread_local_id
+
indOffset
))
:
buffer2dDesc
.
CalculateOffset
(
make_tuple
(
thread_local_id
+
indOffset
,
0
));
compType
opData1
=
type_convert
<
compType
>
{}
(
block_buffer
[
offset1
]);
compType
opData2
=
type_convert
<
compType
>
{}
(
block_buffer
[
offset2
]);
compType
opData1
=
type_convert
<
compType
>
(
block_buffer
[
offset1
]);
compType
opData2
=
type_convert
<
compType
>
(
block_buffer
[
offset2
]);
binop
::
calculate
(
opData1
,
opData2
);
block_buffer
(
offset1
)
=
type_convert
<
compType
>
{}
(
opData1
);
block_buffer
(
offset1
)
=
type_convert
<
compType
>
(
opData1
);
}
__syncthreads
();
...
...
@@ -100,7 +100,7 @@ struct BlockwiseReduction_2d_block_buffer
if
(
thread_local_id
==
0
)
{
compType
tmpVal
=
type_convert
<
compType
>
{}
(
block_buffer
[
0
]);
compType
tmpVal
=
type_convert
<
compType
>
(
block_buffer
[
0
]);
binop
::
calculate
(
accuData
,
tmpVal
);
}
...
...
@@ -131,13 +131,13 @@ struct BlockwiseReduction_2d_block_buffer
index_t
offset2
=
buffer2dDesc
.
CalculateOffset
(
make_tuple
(
otherDimInd
,
thread_local_id
+
indOffset
));
compType
currVal1
=
type_convert
<
compType
>
{}
(
block_buffer
[
offset1
]);
compType
currVal2
=
type_convert
<
compType
>
{}
(
block_buffer
[
offset2
]);
compType
currVal1
=
type_convert
<
compType
>
(
block_buffer
[
offset1
]);
compType
currVal2
=
type_convert
<
compType
>
(
block_buffer
[
offset2
]);
int
currIndex1
=
block_indices_buffer
[
offset1
];
int
currIndex2
=
block_indices_buffer
[
offset2
];
binop
::
calculate
(
currVal1
,
currVal2
,
currIndex1
,
currIndex2
);
block_buffer
(
offset1
)
=
type_convert
<
compType
>
{}
(
currVal1
);
block_buffer
(
offset1
)
=
type_convert
<
compType
>
(
currVal1
);
block_indices_buffer
(
offset1
)
=
currIndex1
;
}
__syncthreads
();
...
...
@@ -150,7 +150,7 @@ struct BlockwiseReduction_2d_block_buffer
{
index_t
offset
=
buffer2dDesc
.
CalculateOffset
(
make_tuple
(
otherDimInd
,
0
));
compType
tmpVal
=
type_convert
<
compType
>
{}
(
block_buffer
[
offset
]);
compType
tmpVal
=
type_convert
<
compType
>
(
block_buffer
[
offset
]);
int
tmpIndex
=
block_indices_buffer
[
offset
];
binop
::
calculate
(
lAccuData
,
tmpVal
,
lAccuIndex
,
tmpIndex
);
...
...
@@ -166,7 +166,7 @@ struct BlockwiseReduction_2d_block_buffer
for
(
index_t
otherDimInd
=
0
;
otherDimInd
<
toReduceBlocks
;
otherDimInd
++
)
{
offset
=
buffer2dDesc
.
CalculateOffset
(
make_tuple
(
thread_local_id
,
otherDimInd
));
compType
currVal
=
type_convert
<
compType
>
{}
(
block_buffer
[
offset
]);
compType
currVal
=
type_convert
<
compType
>
(
block_buffer
[
offset
]);
int
currIndex
=
block_indices_buffer
[
offset
];
binop
::
calculate
(
lAccuData
,
currVal
,
lAccuIndex
,
currIndex
);
...
...
@@ -187,13 +187,13 @@ struct BlockwiseReduction_2d_block_buffer
index_t
offset2
=
buffer2dDesc
.
CalculateOffset
(
make_tuple
(
thread_local_id
+
indOffset
,
0
));
compType
currVal1
=
type_convert
<
compType
>
{}
(
block_buffer
[
offset1
]);
compType
currVal2
=
type_convert
<
compType
>
{}
(
block_buffer
[
offset2
]);
compType
currVal1
=
type_convert
<
compType
>
(
block_buffer
[
offset1
]);
compType
currVal2
=
type_convert
<
compType
>
(
block_buffer
[
offset2
]);
int
currIndex1
=
block_indices_buffer
[
offset1
];
int
currIndex2
=
block_indices_buffer
[
offset2
];
binop
::
calculate
(
currVal1
,
currVal2
,
currIndex1
,
currIndex2
);
block_buffer
(
offset1
)
=
type_convert
<
compType
>
{}
(
currVal1
);
block_buffer
(
offset1
)
=
type_convert
<
compType
>
(
currVal1
);
block_indices_buffer
(
offset1
)
=
currIndex1
;
}
...
...
@@ -202,7 +202,7 @@ struct BlockwiseReduction_2d_block_buffer
if
(
thread_local_id
==
0
)
{
compType
tmpVal
=
type_convert
<
compType
>
{}
(
block_buffer
[
0
]);
compType
tmpVal
=
type_convert
<
compType
>
(
block_buffer
[
0
]);
int
tmpIndex
=
block_indices_buffer
[
0
];
binop
::
calculate
(
accuData
,
tmpVal
,
accuIndex
,
tmpIndex
);
...
...
@@ -227,9 +227,9 @@ struct BlockwiseReduction_2d_block_buffer
}
};
// Initialize the block-wise indices buffer, the index for each element in the block-wise
data
// buffer
//
is calculated according to its position in the buffer and the global starting
index
// Initialize the block-wise indices buffer, the index for each element in the block-wise
//
data
buffer
is calculated according to its position in the buffer and the global starting
// index
template
<
typename
IdxBufferType
>
__device__
static
void
init_buffer_indices
(
IdxBufferType
&
block_indices_buffer
,
int
indexStart
)
{
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
View file @
55159365
...
...
@@ -196,7 +196,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
src_slice_origin_idx
+
dst_data_idx
+
i
*
dst_scalar_step_in_vector
);
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
{}
(
src_buf
[
Number
<
src_offset
>
{}]);
type_convert
<
DstData
>
(
src_buf
[
Number
<
src_offset
>
{}]);
});
const
bool
is_dst_valid
=
...
...
@@ -983,7 +983,7 @@ struct ThreadwiseTensorSliceTransfer_v3
buffer_desc_
.
CalculateOffset
(
dst_data_idx
+
i
*
dst_scalar_step_in_vector
);
dst_tmp_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
{}
(
buffer_
[
Number
<
buffer_offset
>
{}]);
type_convert
<
DstData
>
(
buffer_
[
Number
<
buffer_offset
>
{}]);
});
using
dst_vector_t
=
typename
decltype
(
dst_tmp_vector
)
::
type
;
...
...
@@ -1403,7 +1403,7 @@ struct ThreadwiseTensorSliceTransfer_v4
// TODO: if SrcData and DstData are vetor type, then static_cast may not compile
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
dst_tmp_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
{}
(
src_tmp_vector
.
template
AsType
<
SrcData
>()[
i
]);
type_convert
<
DstData
>
(
src_tmp_vector
.
template
AsType
<
SrcData
>()[
i
]);
});
// copy data from dst_tmp_vector into dst_buf
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
View file @
55159365
...
...
@@ -351,7 +351,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
dst_vector_desc
.
CalculateOffset
(
dst_vector_idx
);
dst_vector
.
template
AsType
<
DstData
>()(
Number
<
dst_vector_offset
>
{})
=
type_convert
<
DstData
>
{}
(
buffer_
[
Number
<
buffer_offset
>
{}]);
type_convert
<
DstData
>
(
buffer_
[
Number
<
buffer_offset
>
{}]);
});
using
dst_vector_t
=
typename
decltype
(
dst_vector
)
::
type
;
...
...
@@ -750,7 +750,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_origin_idx
+
data_to_origin_disp_idx
+
src_vector_idx
);
dst_buf
(
Number
<
dst_offset
>
{})
=
type_convert
<
DstData
>
{}
(
dst_buf
(
Number
<
dst_offset
>
{})
=
type_convert
<
DstData
>
(
src_vector
.
template
AsType
<
DstData
>()[
Number
<
src_vector_offset
>
{}]);
});
});
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp
View file @
55159365
...
...
@@ -248,7 +248,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
// convert from SrcData to DstData here
dst_thread_scratch_
(
idx
)
=
type_convert
<
DstData
>
{}
(
src_thread_scratch_
[
idx
]);
dst_thread_scratch_
(
idx
)
=
type_convert
<
DstData
>
(
src_thread_scratch_
[
idx
]);
});
#else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
...
...
@@ -322,7 +322,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
{
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
// convert from SrcData to DstData here
dst_thread_scratch_
(
idx
)
=
type_convert
<
DstData
>
{}
(
src_thread_scratch_
[
idx
]);
dst_thread_scratch_
(
idx
)
=
type_convert
<
DstData
>
(
src_thread_scratch_
[
idx
]);
});
}
#endif
...
...
composable_kernel/include/utility/data_type.hpp
View file @
55159365
...
...
@@ -927,23 +927,36 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
static
__host__
__device__
float
bf16_to_f32
(
ushort
src_val
)
// Convert X to Y
template
<
typename
Y
,
typename
X
>
__host__
__device__
Y
type_convert
(
X
x
)
{
return
static_cast
<
Y
>
(
x
);
}
// convert bfp16 to fp32
template
<
>
inline
__host__
__device__
float
type_convert
(
ushort
x
)
{
union
{
uint32_t
int32
;
float
fp32
;
}
u
=
{
uint32_t
(
src_val
)
<<
16
};
}
u
=
{
uint32_t
(
x
)
<<
16
};
return
u
.
fp32
;
}
static
__host__
__device__
ushort
f32_to_bf16
(
float
src_val
)
// convert fp32 to bfp16
template
<
>
inline
__host__
__device__
ushort
type_convert
(
float
x
)
{
union
{
float
fp32
;
uint32_t
int32
;
}
u
=
{
src_val
};
}
u
=
{
x
};
if
(
~
u
.
int32
&
0x7f800000
)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
...
...
@@ -976,40 +989,14 @@ static __host__ __device__ ushort f32_to_bf16(float src_val)
// the bloat16's mantissa bits are all 0.
u
.
int32
|=
0x10000
;
// Preserve signaling NaN
}
return
uint16_t
(
u
.
int32
>>
16
);
}
// data type conversion
template
<
typename
T
>
struct
type_convert
{
template
<
typename
X
>
__device__
T
operator
()(
X
x
)
const
{
return
static_cast
<
T
>
(
x
);
}
};
template
<
>
template
<
>
__device__
float
type_convert
<
float
>::
operator
()
<
ushort
>
(
ushort
x
)
const
{
return
bf16_to_f32
(
x
);
}
template
<
>
template
<
>
__device__
ushort
type_convert
<
ushort
>::
operator
()
<
float
>
(
float
x
)
const
{
return
f32_to_bf16
(
x
);
return
uint16_t
(
u
.
int32
>>
16
);
}
// TODO: deprecate this
template
<
typename
T
>
struct
inner_product_with_conversion
{
static
constexpr
auto
convert
=
type_convert
<
T
>
();
template
<
typename
X
,
index_t
N
>
__device__
T
operator
()(
typename
vector_type
<
X
,
N
>::
type
a
,
typename
vector_type
<
X
,
N
>::
type
b
)
const
...
...
@@ -1020,13 +1007,16 @@ struct inner_product_with_conversion
T
acc
=
0
;
static_for
<
0
,
N
,
1
>
{}([
&
](
auto
i
)
{
acc
+=
convert
(
a_vector
.
Scalars
()[
i
])
*
convert
(
b_vector
.
Scalars
()[
i
]);
acc
+=
type_
convert
<
T
>
(
a_vector
.
Scalars
()[
i
])
*
type_
convert
<
T
>
(
b_vector
.
Scalars
()[
i
]);
});
return
acc
;
}
__device__
T
operator
()(
float_t
a
,
float_t
b
)
const
{
return
convert
(
a
)
*
convert
(
b
);
}
__device__
T
operator
()(
float_t
a
,
float_t
b
)
const
{
return
type_convert
<
T
>
(
a
)
*
type_convert
<
T
>
(
b
);
}
__device__
T
operator
()(
int8x4_t
a
,
int8x4_t
b
)
const
{
...
...
@@ -1036,7 +1026,8 @@ struct inner_product_with_conversion
T
acc
=
0
;
static_for
<
0
,
4
,
1
>
{}([
&
](
auto
i
)
{
acc
+=
convert
(
a_vector
.
AsType
<
int8_t
>
()[
i
])
*
convert
(
b_vector
.
AsType
<
int8_t
>
()[
i
]);
acc
+=
type_convert
<
T
>
(
a_vector
.
AsType
<
int8_t
>
()[
i
])
*
type_convert
<
T
>
(
b_vector
.
AsType
<
int8_t
>
()[
i
]);
});
return
acc
;
...
...
@@ -1050,7 +1041,8 @@ struct inner_product_with_conversion
T
acc
=
0
;
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
i
)
{
acc
+=
convert
(
a_vector
.
AsType
<
int8_t
>
()[
i
])
*
convert
(
b_vector
.
AsType
<
int8_t
>
()[
i
]);
acc
+=
type_convert
<
T
>
(
a_vector
.
AsType
<
int8_t
>
()[
i
])
*
type_convert
<
T
>
(
b_vector
.
AsType
<
int8_t
>
()[
i
]);
});
return
acc
;
...
...
@@ -1064,7 +1056,8 @@ struct inner_product_with_conversion
T
acc
=
0
;
static_for
<
0
,
16
,
1
>
{}([
&
](
auto
i
)
{
acc
+=
convert
(
a_vector
.
AsType
<
int8_t
>
()[
i
])
*
convert
(
b_vector
.
AsType
<
int8_t
>
()[
i
]);
acc
+=
type_convert
<
T
>
(
a_vector
.
AsType
<
int8_t
>
()[
i
])
*
type_convert
<
T
>
(
b_vector
.
AsType
<
int8_t
>
()[
i
]);
});
return
acc
;
...
...
composable_kernel/include/utility/inner_product.hpp
View file @
55159365
...
...
@@ -28,12 +28,6 @@ __device__ void inner_product<float, float, float>(const float& a, const float&
#endif
}
template
<
>
__device__
void
inner_product
<
ushort
,
ushort
,
float
>
(
const
ushort
&
a
,
const
ushort
&
b
,
float
&
c
)
{
c
+=
bf16_to_f32
(
a
)
*
bf16_to_f32
(
b
);
}
template
<
>
__device__
void
inner_product
<
float2_t
,
float2_t
,
float
>
(
const
float2_t
&
a
,
const
float2_t
&
b
,
float
&
c
)
...
...
@@ -90,13 +84,12 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h
c
=
__builtin_amdgcn_sdot2
(
a
,
b
,
c
,
false
);
#endif
#else
const
auto
convert
=
type_convert
<
int32_t
>
{};
const
vector_type
<
half_t
,
2
>
a_vector
{
a
};
const
vector_type
<
half_t
,
2
>
b_vector
{
b
};
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
i
)
{
c
+=
convert
(
a_vector
.
AsType
<
half_t
>
()[
i
])
*
convert
(
b_vector
.
AsType
<
half_t
>
()[
i
]);
c
+=
type_convert
<
int32_t
>
(
a_vector
.
AsType
<
half_t
>
()[
i
])
*
type_convert
<
int32_t
>
(
b_vector
.
AsType
<
half_t
>
()[
i
]);
});
#endif
}
...
...
@@ -156,13 +149,12 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
c
=
__builtin_amdgcn_sdot4
(
as_type
<
int32_t
>
(
a
),
as_type
<
int32_t
>
(
b
),
c
,
false
);
#endif
#else
const
auto
convert
=
type_convert
<
int32_t
>
{};
const
vector_type
<
int8_t
,
4
>
a_vector
{
a
};
const
vector_type
<
int8_t
,
4
>
b_vector
{
b
};
static_for
<
0
,
4
,
1
>
{}([
&
](
auto
i
)
{
c
+=
convert
(
a_vector
.
AsType
<
int8_t
>
()[
i
])
*
convert
(
b_vector
.
AsType
<
int8_t
>
()[
i
]);
c
+=
type_convert
<
int32_t
>
(
a_vector
.
AsType
<
int8_t
>
()[
i
])
*
type_convert
<
int32_t
>
(
b_vector
.
AsType
<
int8_t
>
()[
i
]);
});
#endif
}
...
...
composable_kernel/include/utility/reduction_operator.hpp
View file @
55159365
...
...
@@ -165,7 +165,7 @@ struct unary_identic
scaler
=
1.0
f
/
static_cast
<
float
>
(
divider
);
};
__device__
inline
constexpr
T
operator
()(
T
a
)
const
{
return
a
*
type_convert
<
T
>
{}
(
scaler
);
};
__device__
inline
constexpr
T
operator
()(
T
a
)
const
{
return
a
*
type_convert
<
T
>
(
scaler
);
};
float
scaler
=
1.0
f
;
};
...
...
@@ -187,7 +187,7 @@ struct unary_square
{
a
=
a
*
a
;
return
a
*
type_convert
<
T
>
{}
(
scaler
);
return
a
*
type_convert
<
T
>
(
scaler
);
};
float
scaler
=
1.0
f
;
...
...
@@ -210,7 +210,7 @@ struct unary_abs
{
a
=
abs
(
a
);
return
a
*
type_convert
<
T
>
{}
(
scaler
);
return
a
*
type_convert
<
T
>
(
scaler
);
};
float
scaler
=
1.0
f
;
...
...
@@ -249,7 +249,7 @@ struct unary_abs<half_t, hasDividing>
{
a
=
static_cast
<
half_t
>
(
__habs
(
a
));
return
a
*
type_convert
<
half_t
>
{}
(
scaler
);
return
a
*
type_convert
<
half_t
>
(
scaler
);
};
float
scaler
=
1.0
f
;
...
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
55159365
...
...
@@ -82,8 +82,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
{
if
constexpr
(
is_same
<
TIn
,
ushort
>::
value
)
{
v
+=
ck
::
bf16_to_f32
(
in
(
n
,
c
,
hi
,
wi
))
*
ck
::
bf16_to_f32
(
wei
(
k
,
c
,
y
,
x
));
v
+=
ck
::
type_convert
<
float
>
(
in
(
n
,
c
,
hi
,
wi
))
*
ck
::
type_convert
<
float
>
(
wei
(
k
,
c
,
y
,
x
));
}
else
{
...
...
@@ -97,7 +97,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
if
constexpr
(
is_same
<
TOut
,
ushort
>::
value
)
{
out
(
n
,
k
,
ho
,
wo
)
=
f32_to_bf16
(
v
);
out
(
n
,
k
,
ho
,
wo
)
=
type_convert
<
ushort
>
(
v
);
}
else
{
...
...
@@ -120,8 +120,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
{
if
constexpr
(
is_same
<
TIn
,
ushort
>::
value
)
{
v
+=
ck
::
bf16_to_f32
(
in
(
n
,
hi
,
wi
,
c
))
*
ck
::
bf16_to_f32
(
wei
(
k
,
y
,
x
,
c
));
v
+=
ck
::
type_convert
<
float
>
(
in
(
n
,
hi
,
wi
,
c
))
*
ck
::
type_convert
<
float
>
(
wei
(
k
,
y
,
x
,
c
));
}
else
{
...
...
@@ -134,7 +134,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
}
if
constexpr
(
is_same
<
TOut
,
ushort
>::
value
)
{
out
(
n
,
ho
,
wo
,
k
)
=
f32_to_bf16
(
v
);
out
(
n
,
ho
,
wo
,
k
)
=
ck
::
type_convert
<
ushort
>
(
v
);
}
else
{
...
...
host/host_tensor/include/host_tensor_generator.hpp
View file @
55159365
...
...
@@ -5,15 +5,25 @@
#include "config.hpp"
#include "data_type.hpp"
template
<
typename
T
>
struct
GeneratorTensor_0
{
template
<
typename
...
Is
>
T
operator
()(
Is
...)
{
return
T
{
0
};
}
};
template
<
typename
T
>
struct
GeneratorTensor_1
{
int
value
=
1
;
template
<
typename
...
Is
>
float
operator
()(
Is
...)
T
operator
()(
Is
...)
{
return
value
;
return
ck
::
type_convert
<
T
>
(
value
)
;
}
};
...
...
@@ -25,7 +35,7 @@ struct GeneratorTensor_1<ushort>
template
<
typename
...
Is
>
ushort
operator
()(
Is
...)
{
return
ck
::
f32_to_bf16
(
value
);
return
ck
::
type_convert
<
ushort
>
(
value
);
}
};
...
...
@@ -41,17 +51,6 @@ struct GeneratorTensor_1<int8_t>
}
};
struct
GeneratorTensor_0
{
int
value
=
0
;
template
<
typename
...
Is
>
float
operator
()(
Is
...)
{
return
value
;
}
};
template
<
typename
T
>
struct
GeneratorTensor_2
{
...
...
@@ -59,7 +58,7 @@ struct GeneratorTensor_2
int
max_value
=
1
;
template
<
typename
...
Is
>
float
operator
()(
Is
...)
T
operator
()(
Is
...)
{
return
(
std
::
rand
()
%
(
max_value
-
min_value
))
+
min_value
;
}
...
...
@@ -75,7 +74,7 @@ struct GeneratorTensor_2<ushort>
ushort
operator
()(
Is
...)
{
float
tmp
=
(
std
::
rand
()
%
(
max_value
-
min_value
))
+
min_value
;
return
ck
::
f32_to_bf16
(
tmp
);
return
ck
::
type_convert
<
ushort
>
(
tmp
);
}
};
...
...
@@ -99,7 +98,7 @@ struct GeneratorTensor_3
T
max_value
=
1
;
template
<
typename
...
Is
>
float
operator
()(
Is
...)
T
operator
()(
Is
...)
{
float
tmp
=
float
(
std
::
rand
())
/
float
(
RAND_MAX
);
...
...
@@ -120,7 +119,7 @@ struct GeneratorTensor_3<ushort>
float
fp32_tmp
=
min_value
+
tmp
*
(
max_value
-
min_value
);
return
ck
::
f32_to_bf16
(
fp32_tmp
);
return
ck
::
type_convert
<
ushort
>
(
fp32_tmp
);
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment