Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
baf405cd
Commit
baf405cd
authored
Mar 02, 2022
by
ltqin
Browse files
simple transform result correct
parent
df22ba01
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
34 deletions
+48
-34
device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+48
-32
example/14_conv2d_backward_weight_xdl/main.cpp
example/14_conv2d_backward_weight_xdl/main.cpp
+0
-2
No files found.
device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
baf405cd
...
@@ -370,21 +370,9 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -370,21 +370,9 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
float
ave_time
=
0
;
float
ave_time
=
0
;
if
(
has_main_k0_block_loop
)
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
if
(
nrepeat
>
0
)
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
ave_time
=
ave_time
=
launch_and_time_kernel
(
kernel
,
launch_and_time_kernel
(
kernel
,
nrepeat
,
nrepeat
,
...
@@ -402,24 +390,16 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -402,24 +390,16 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
}
}
else
if
(
kbatch
>
1
||
nrepeat
<=
0
)
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
hipGetErrorString
(
hipMemset
(
GridwiseGemm
,
arg
.
p_c_grid_
,
ADataType
,
// TODO: distiguish A/B datatype
0
,
CDataType
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetElementSpaceSize
()
*
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
sizeof
(
CDataType
)));
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
ave_time
=
launch_kernel
(
kernel
,
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
...
@@ -434,6 +414,42 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -434,6 +414,42 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
}
}
};
if
(
has_main_k0_block_loop
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
Run
(
kernel
);
}
return
ave_time
;
return
ave_time
;
}
}
...
...
example/14_conv2d_backward_weight_xdl/main.cpp
View file @
baf405cd
...
@@ -203,8 +203,6 @@ int main(int argc, char* argv[])
...
@@ -203,8 +203,6 @@ int main(int argc, char* argv[])
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
out_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
out_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x_device_result
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x_device_result
.
mData
.
data
());
LogRangeAsType
<
float
>
(
std
::
cout
<<
"wei_device(before): "
,
wei_k_c_y_x_device_result
.
mData
,
","
)
<<
std
::
endl
;
// do GEMM
// do GEMM
auto
conv
=
DeviceConvWrWInstance
{};
auto
conv
=
DeviceConvWrWInstance
{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment