Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
63cdd923
Unverified
Commit
63cdd923
authored
Jun 18, 2022
by
Shaojie WANG
Committed by
GitHub
Jun 17, 2022
Browse files
use universal workspace pointer in bwd-weight (#286)
parent
c7a96ed5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
14 deletions
+3
-14
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+3
-14
No files found.
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
63cdd923
...
...
@@ -900,9 +900,6 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
b_grid_desc_kbatch_k0_n_k1_
=
descs
[
I1
];
c_grid_desc_m_n_
=
descs
[
I2
];
// init work space
p_c_workspace_grid_
=
nullptr
;
block_2_ctile_map_
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
k_batch_
);
...
...
@@ -939,9 +936,6 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
index_t
k_batch_
;
// external work space
void
*
p_c_workspace_grid_
;
};
// Invoker
...
...
@@ -1017,7 +1011,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
// run kernel for bf16 with splitk
const
auto
run_bf16_splitk
=
[
&
](
const
auto
&
kernel
)
{
hipGetErrorString
(
hipMemset
(
arg
.
p_
c_
workspace_
grid_
,
arg
.
p_workspace_
,
0
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetElementSpaceSize
()
*
sizeof
(
AccDataType
)));
...
...
@@ -1030,7 +1024,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
static_cast
<
AccDataType
*>
(
arg
.
p_
c_
workspace_
grid_
),
static_cast
<
AccDataType
*>
(
arg
.
p_workspace_
),
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
...
...
@@ -1072,7 +1066,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
dim3
(
type_convert_grid_size
),
dim3
(
256
),
0
,
static_cast
<
AccDataType
*>
(
arg
.
p_
c_
workspace_
grid_
),
static_cast
<
AccDataType
*>
(
arg
.
p_workspace_
),
p_c_grid_tmp_bf16_
,
a_grid_desc_m0_
,
b_grid_desc_m0_
,
...
...
@@ -1448,11 +1442,6 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
{
return
GetWorkSpaceSize
<
NumDimSpatial
>
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
void
SetWorkSpacePointer
(
BaseArgument
*
p_arg
,
void
*
workspace_ptr
)
const
override
{
dynamic_cast
<
Argument
*>
(
p_arg
)
->
p_c_workspace_grid_
=
workspace_ptr
;
}
};
}
// namespace device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment