Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
32d06c66
Commit
32d06c66
authored
May 30, 2022
by
wangshaojie6
Browse files
support bf16 splitk kernel for convnd bwd weight
parent
bccc6d8b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
134 additions
and
31 deletions
+134
-31
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp
...nvnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp
+3
-19
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+2
-0
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+129
-3
library/include/ck/library/host_tensor/device.hpp
library/include/ck/library/host_tensor/device.hpp
+0
-9
No files found.
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp
View file @
32d06c66
...
@@ -329,24 +329,7 @@ int main(int argc, char* argv[])
...
@@ -329,24 +329,7 @@ int main(int argc, char* argv[])
DeviceMem
wei_work_space_device_buf
(
bwd_weight_workspace_size
);
DeviceMem
wei_work_space_device_buf
(
bwd_weight_workspace_size
);
wei_work_space_device_buf
.
SetZero
();
wei_work_space_device_buf
.
SetZero
();
argument
=
conv
->
MakeArgumentPointer
(
conv
->
SetWorkSpacePointer
(
argument
.
get
(),
wei_work_space_device_buf
.
GetDeviceBuffer
());
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
AccDataType
*>
(
wei_work_space_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
params
.
N_
,
params
.
K_
,
params
.
C_
,
params
.
input_spatial_lengths_
,
params
.
filter_spatial_lengths_
,
output_spatial_lengths
,
params
.
conv_filter_strides_
,
params
.
conv_filter_dilations_
,
params
.
input_left_pads_
,
params
.
input_right_pads_
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{},
split_k
);
if
(
!
conv
->
IsSupportedArgument
(
argument
.
get
()))
if
(
!
conv
->
IsSupportedArgument
(
argument
.
get
()))
{
{
...
@@ -358,6 +341,7 @@ int main(int argc, char* argv[])
...
@@ -358,6 +341,7 @@ int main(int argc, char* argv[])
conv_ave_time
=
invoker
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
conv_ave_time
=
invoker
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
#if 0
// do type convert
// do type convert
auto type_convert = DeviceUnaryElementwiseTypeConvertInstance{};
auto type_convert = DeviceUnaryElementwiseTypeConvertInstance{};
auto type_convert_invoker = type_convert.MakeInvokerPointer();
auto type_convert_invoker = type_convert.MakeInvokerPointer();
...
@@ -381,7 +365,7 @@ int main(int argc, char* argv[])
...
@@ -381,7 +365,7 @@ int main(int argc, char* argv[])
type_convert_ave_time =
type_convert_ave_time =
type_convert_invoker->Run(type_convert_argument.get(), StreamConfig{nullptr, time_kernel});
type_convert_invoker->Run(type_convert_argument.get(), StreamConfig{nullptr, time_kernel});
// type_convert_invoker->Run(type_convert_argument.get(), StreamConfig{nullptr, time_kernel});
// type_convert_invoker->Run(type_convert_argument.get(), StreamConfig{nullptr, time_kernel});
#endif
// host code to check if conv give me a right result
// host code to check if conv give me a right result
// Tensor<AccDataType> wei_k_c_y_x_device_result_fp32(
// Tensor<AccDataType> wei_k_c_y_x_device_result_fp32(
// ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
// ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
...
...
include/ck/tensor_operation/gpu/device/device_base.hpp
View file @
32d06c66
...
@@ -42,6 +42,8 @@ struct BaseOperator
...
@@ -42,6 +42,8 @@ struct BaseOperator
virtual
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
)
const
{
return
0
;
}
virtual
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
)
const
{
return
0
;
}
// Hook for handing an externally allocated workspace buffer to an argument
// object. The base implementation ignores both parameters and does nothing.
virtual void SetWorkSpacePointer(BaseArgument* /*p_arg*/, void* /*p_workspace*/) const {}
virtual
~
BaseOperator
()
{}
virtual
~
BaseOperator
()
{}
};
};
...
...
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
32d06c66
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_bwd_weight.hpp"
#include "gridwise_gemm_xdlops_bwd_weight.hpp"
#include "gridwise_unary_elementwise_1d.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -628,6 +629,54 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -628,6 +629,54 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
1
);
1
);
}
}
// type convert descs
template
<
typename
Desc_M0
>
static
auto
PadDescriptor_M0_1d
(
Desc_M0
desc_m0
,
index_t
gridSize
,
index_t
blockSize
)
{
const
auto
m0
=
desc_m0
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
blockSize
*
4
;
const
auto
pad
=
math
::
integer_least_multiple
(
m0
,
loop_step
)
-
m0
;
const
auto
desc_m0_pad
=
transform_tensor_descriptor
(
desc_m0
,
make_tuple
(
make_right_pad_transform
(
m0
,
pad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
desc_m0_pad
;
}
// Builds a padded 1-D descriptor from an N-D shape/stride description.
//
// Dim       : number of leading entries read from `shape` and `stride`.
// shape     : per-dimension lengths; entries [0, Dim) are used.
// stride    : per-dimension strides; entries [0, Dim) are used.
// gridSize,
// blockSize : forwarded to PadDescriptor_M0_1d to size the right padding.
//
// For Dim > 1 the dimensions are first merged into a single dimension; for
// Dim == 1 the naive descriptor is padded as-is.
template <index_t Dim>
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
                              const std::vector<index_t>& stride,
                              index_t gridSize,
                              index_t blockSize)
{
    auto lengths = generate_tuple([&](auto idx) { return shape[idx]; }, Number<Dim>{});
    auto strides = generate_tuple([&](auto idx) { return stride[idx]; }, Number<Dim>{});

    // nd desc - [s0, s1, s2, ...]
    const auto nd_desc = make_naive_tensor_descriptor(lengths, strides);

    if constexpr(!(Dim > 1))
    {
        // already one-dimensional: only padding is required
        return PadDescriptor_M0_1d(nd_desc, gridSize, blockSize);
    }
    else
    {
        // merge nd to 1d desc - [s0 * s1 * ...]
        const auto merged_desc = transform_tensor_descriptor(
            nd_desc,
            make_tuple(make_merge_transform(lengths)),
            make_tuple(generate_sequence_v2([&](auto idx) { return idx; }, Number<Dim>{})),
            make_tuple(Sequence<0>{}));

        return PadDescriptor_M0_1d(merged_desc, gridSize, blockSize);
    }
}
// Element-wise functor used by the type-conversion pass.
// NOTE(review): template arguments are <ck::bhalf_t, float>; presumably this
// converts float values to bhalf_t -- confirm the parameter order against
// UnaryTypeConvert's definition.
using TypeConvertFunctor =
    ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;

// Descriptor type produced by MakeDescriptor_M0<1>; the argument values here
// are placeholders that only serve to name the resulting type.
using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1));

// Gridwise unary element-wise op instantiated with AccDataType and InDataType
// as its first two type arguments and a trailing constant of 4.
using GridwiseUEltwise = GridwiseUnaryElementwise_1D<AccDataType,
                                                     InDataType,
                                                     GridDesc_M0,
                                                     TypeConvertFunctor,
                                                     4>;
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NumDimSpatial
>
());
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NumDimSpatial
>
());
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
...
@@ -851,6 +900,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -851,6 +900,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
b_grid_desc_kbatch_k0_n_k1_
=
descs
[
I1
];
b_grid_desc_kbatch_k0_n_k1_
=
descs
[
I1
];
c_grid_desc_m_n_
=
descs
[
I2
];
c_grid_desc_m_n_
=
descs
[
I2
];
// init work space
p_c_workspace_grid_
=
nullptr
;
block_2_ctile_map_
=
block_2_ctile_map_
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
k_batch_
);
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
k_batch_
);
...
@@ -887,6 +939,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -887,6 +939,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
index_t
k_batch_
;
index_t
k_batch_
;
// external work space
void
*
p_c_workspace_grid_
;
};
};
// Invoker
// Invoker
...
@@ -959,6 +1014,64 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -959,6 +1014,64 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
};
};
// run kernel for bf16 with splitk
// Launches the bf16 split-K convolution kernel, writing its output into the
// external workspace (viewed as AccDataType) instead of the final C buffer.
// The workspace is zero-filled first, and the measured time is stored in
// `ave_time`.
const auto Run_bf16_splitk = [&](const auto& kernel) {
    const auto workspace_bytes =
        arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
        sizeof(AccDataType);

    // NOTE(review): the hipMemset status is only turned into a string and then
    // discarded, so a failed clear goes unreported. The memset also takes no
    // stream argument while the launch below receives stream_config -- confirm
    // the ordering is as intended.
    hipGetErrorString(hipMemset(arg.p_c_workspace_grid_, 0, workspace_bytes));

    ave_time = launch_and_time_kernel(stream_config,
                                      kernel,
                                      dim3(grid_size),
                                      dim3(BlockSize),
                                      0,
                                      arg.p_a_grid_,
                                      arg.p_b_grid_,
                                      static_cast<AccDataType*>(arg.p_c_workspace_grid_),
                                      arg.a_grid_desc_kbatch_k0_m_k1_,
                                      arg.b_grid_desc_kbatch_k0_n_k1_,
                                      arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                                      arg.a_element_op_,
                                      arg.b_element_op_,
                                      arg.c_element_op_,
                                      arg.block_2_ctile_map_);
};
// kernel for type conversion
std
::
vector
<
std
::
size_t
>
filter_dims
{
static_cast
<
std
::
size_t
>
(
arg
.
Conv_K_
),
static_cast
<
std
::
size_t
>
(
arg
.
Conv_C_
)};
filter_dims
.
insert
(
std
::
end
(
filter_dims
),
std
::
begin
(
arg
.
filter_spatial_lengths_
),
std
::
end
(
arg
.
filter_spatial_lengths_
));
int
tensor_size
=
std
::
accumulate
(
filter_dims
.
begin
(),
filter_dims
.
end
(),
1
,
std
::
multiplies
<
int
>
{});
GridDesc_M0
a_grid_desc_m0_
=
MakeDescriptor_M0
<
1
>
({
tensor_size
},
{
1
},
240
,
256
);
GridDesc_M0
b_grid_desc_m0_
=
MakeDescriptor_M0
<
1
>
({
tensor_size
},
{
1
},
240
,
256
);
// run kernel for type conversion
void
*
p_c_grid_tmp_
=
static_cast
<
void
*>
(
arg
.
p_c_grid_
);
InDataType
*
p_c_grid_tmp_bf16_
=
static_cast
<
InDataType
*>
(
p_c_grid_tmp_
);
// Launches the element-wise type-conversion kernel that reads the workspace
// (as AccDataType) and writes to p_c_grid_tmp_bf16_; returns the value
// reported by launch_and_time_kernel.
// The launch uses a grid of 240 blocks with 256 threads each, the same
// literals passed to MakeDescriptor_M0 when building the two M0 descriptors.
const auto Run_type_convert = [&](const auto& kernel) {
    return launch_and_time_kernel(stream_config,
                                  kernel,
                                  dim3(240),
                                  dim3(256),
                                  0,
                                  static_cast<AccDataType*>(arg.p_c_workspace_grid_),
                                  p_c_grid_tmp_bf16_,
                                  a_grid_desc_m0_,
                                  b_grid_desc_m0_,
                                  TypeConvertFunctor{});
};
if
constexpr
(
std
::
is_same
<
InDataType
,
ck
::
bhalf_t
>::
value
)
if
constexpr
(
std
::
is_same
<
InDataType
,
ck
::
bhalf_t
>::
value
)
{
{
if
(
has_main_k0_block_loop
)
if
(
has_main_k0_block_loop
)
...
@@ -983,7 +1096,14 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -983,7 +1096,14 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
const
auto
kernel_type_convert
=
kernel_unary_elementwise_1d
<
GridwiseUEltwise
,
AccDataType
,
InDataType
,
GridDesc_M0
,
TypeConvertFunctor
>
;
const
auto
kernel_conv
=
kernel_gemm_xdlops_bwd_weight
<
GridwiseGemmAtomicAddFloatBf16Splitk
,
GridwiseGemmAtomicAddFloatBf16Splitk
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
...
@@ -997,7 +1117,8 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -997,7 +1117,8 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
true
>
;
Run
(
kernel
);
Run_bf16_splitk
(
kernel_conv
);
ave_time
+=
Run_type_convert
(
kernel_type_convert
);
}
}
}
}
else
else
...
@@ -1036,7 +1157,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -1036,7 +1157,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
false
>
;
Run
(
kernel
);
Run
_bf16_splitk
(
kernel
);
}
}
}
}
}
}
...
@@ -1319,6 +1440,11 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -1319,6 +1440,11 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
{
{
return
GetWorkSpaceSize
<
NumDimSpatial
>
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
GetWorkSpaceSize
<
NumDimSpatial
>
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
// Stores the caller-provided workspace pointer in the given argument object.
//
// p_arg         : argument object; it is downcast to this op's Argument type.
// workspace_ptr : buffer address written to Argument::p_c_workspace_grid_.
//
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so passing an argument of a different type is not handled here.
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
{
    Argument* const conv_arg = dynamic_cast<Argument*>(p_arg);

    conv_arg->p_c_workspace_grid_ = workspace_ptr;
}
};
};
}
// namespace device
}
// namespace device
...
...
library/include/ck/library/host_tensor/device.hpp
View file @
32d06c66
...
@@ -111,15 +111,6 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
...
@@ -111,15 +111,6 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
}
}
else
else
{
{
printf
(
"%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}
\n
"
,
__func__
,
grid_dim
.
x
,
grid_dim
.
y
,
grid_dim
.
z
,
block_dim
.
x
,
block_dim
.
y
,
block_dim
.
z
);
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
return
0
;
return
0
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment