Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a26c802d
"...composable_kernel_rocm.git" did not exist on "17ed368f5882dc71f70511bef86ce0831fd12f4d"
Commit
a26c802d
authored
May 26, 2022
by
wangshaojie6
Browse files
add some code for bfp16
parent
2927524e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
31 deletions
+75
-31
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
+3
-3
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+72
-28
No files found.
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
View file @
a26c802d
...
@@ -18,9 +18,9 @@
...
@@ -18,9 +18,9 @@
#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "reference_conv_backward_weight.hpp"
#include "reference_conv_backward_weight.hpp"
using
InDataType
=
ck
::
half_t
;
using
InDataType
=
ck
::
b
half_t
;
using
WeiDataType
=
ck
::
half_t
;
using
WeiDataType
=
ck
::
b
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
b
half_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
...
...
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
a26c802d
...
@@ -963,37 +963,81 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -963,37 +963,81 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
{
{
if
(
has_main_k0_block_loop
)
if
(
has_main_k0_block_loop
)
{
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
if
(
kbatch
==
1
)
GridwiseGemm
,
{
ADataType
,
// TODO: distiguish A/B datatype
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
CDataType
,
GridwiseGemm
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
ADataType
,
// TODO: distiguish A/B datatype
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
CDataType
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
InElementwiseOperation
,
remove_reference_t
<
WeiElementwiseOperation
,
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
OutElementwiseOperation
,
true
>
;
InElementwiseOperation
,
WeiElementwiseOperation
,
Run
(
kernel
);
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
GridwiseGemmAtomicAddFloatForBf16
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
Run
(
kernel
);
}
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
if
(
kbatch
==
1
)
GridwiseGemm
,
{
ADataType
,
// TODO: distiguish A/B datatype
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
CDataType
,
GridwiseGemm
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
ADataType
,
// TODO: distiguish A/B datatype
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
CDataType
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
InElementwiseOperation
,
remove_reference_t
<
WeiElementwiseOperation
,
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
OutElementwiseOperation
,
false
>
;
InElementwiseOperation
,
WeiElementwiseOperation
,
Run
(
kernel
);
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
GridwiseGemmAtomicAddFloatForBf16
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
Run
(
kernel
);
}
}
}
}
}
else
else
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment