Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
4634b120
Unverified
Commit
4634b120
authored
Jun 22, 2022
by
Shaojie WANG
Committed by
GitHub
Jun 21, 2022
Browse files
fix Issue 291 (#294)
* rename for typeconvert functor * refine code
parent
15c89e81
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
71 additions
and
128 deletions
+71
-128
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+71
-128
No files found.
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
4634b120
...
...
@@ -433,7 +433,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
using
namespace
ck
;
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
Hi
=
input_spatial_lengths
[
2
];
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
Do
=
output_spatial_lengths
[
0
];
...
...
@@ -671,11 +671,14 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return
PadDescriptor_M0_1d
(
desc
,
gridSize
,
blockSize
);
}
using
TypeConvertFunctor
=
using
TypeConvertF
p32ToBf16F
unctor
=
ck
::
tensor_operation
::
element_wise
::
UnaryTypeConvert
<
ck
::
bhalf_t
,
float
>
;
using
GridDesc_M0
=
decltype
(
MakeDescriptor_M0
<
1
>
({
1
},
{
1
},
1
,
1
));
using
GridwiseUEltwise
=
GridwiseUnaryElementwise_1D
<
AccDataType
,
InDataType
,
GridDesc_M0
,
TypeConvertFunctor
,
4
>
;
using
GridwiseUEltwise
=
GridwiseUnaryElementwise_1D
<
AccDataType
,
InDataType
,
GridDesc_M0
,
TypeConvertFp32ToBf16Functor
,
4
>
;
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NumDimSpatial
>
());
...
...
@@ -979,19 +982,18 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const
auto
K0
=
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
float
ave_time
=
0
;
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
auto
run_conv
=
[
&
](
const
auto
&
kernel
)
{
hipGetErrorString
(
hipMemset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetElementSpaceSize
()
*
sizeof
(
CDataType
)));
ave_time
=
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
...
...
@@ -1016,8 +1018,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetElementSpaceSize
()
*
sizeof
(
AccDataType
)));
ave_time
=
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
...
...
@@ -1059,7 +1060,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
// run kernel for type conversion
void
*
p_c_grid_tmp_
=
static_cast
<
void
*>
(
arg
.
p_c_grid_
);
InDataType
*
p_c_grid_tmp_bf16_
=
static_cast
<
InDataType
*>
(
p_c_grid_tmp_
);
const
auto
R
un_type_convert
=
[
&
](
const
auto
&
kernel
)
{
const
auto
r
un_type_convert
=
[
&
](
const
auto
&
kernel
)
{
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
...
...
@@ -1070,14 +1071,15 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
p_c_grid_tmp_bf16_
,
a_grid_desc_m0_
,
b_grid_desc_m0_
,
TypeConvertFunctor
{});
TypeConvertF
p32ToBf16F
unctor
{});
return
elapsed_time
;
};
if
constexpr
(
std
::
is_same
<
InDataType
,
ck
::
bhalf_t
>::
value
)
{
if
(
has_main_k0_block_loop
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
if
(
kbatch
==
1
)
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
...
...
@@ -1092,9 +1094,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
has_main_loop
>
;
Run
(
kernel
);
return
run_conv
(
kernel
);
}
else
{
...
...
@@ -1103,7 +1105,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
AccDataType
,
InDataType
,
GridDesc_M0
,
TypeConvertFunctor
>
;
TypeConvertF
p32ToBf16F
unctor
>
;
const
auto
kernel_conv
=
kernel_gemm_xdlops_bwd_weight
<
GridwiseGemmAtomicAddFloatBf16Splitk
,
...
...
@@ -1117,56 +1119,28 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
has_main_loop
>
;
run_bf16_splitk
(
kernel_conv
);
ave_time
+=
Run_type_convert
(
kernel_type_convert
);
}
float
elapsed_time
=
0
;
elapsed_time
+=
run_bf16_splitk
(
kernel_conv
);
elapsed_time
+=
run_type_convert
(
kernel_type_convert
);
return
elapsed_time
;
}
else
{
if
(
kbatch
==
1
)
};
if
(
has_main_k0_block_loop
)
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
Run
(
kernel
);
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
GridwiseGemmAtomicAddFloatBf16Splitk
,
ADataType
,
// TODO: distiguish A/B datatype
AccDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
run_bf16_splitk
(
kernel
);
}
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
else
{
if
(
has_main_k0_block_loop
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
if
(
kbatch
==
1
)
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
...
...
@@ -1181,9 +1155,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
has_main_loop
>
;
Run
(
kernel
);
return
run_conv
(
kernel
);
}
else
{
...
...
@@ -1199,49 +1173,18 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
has_main_loop
>
;
Run
(
kernel
);
return
run_conv
(
kernel
);
}
}
else
{
if
(
kbatch
==
1
)
};
if
(
has_main_k0_block_loop
)
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
Run
(
kernel
);
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
GridwiseGemmAtomicAdd
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
Run
(
kernel
);
}
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment