"...text-generation-inference.git" did not exist on "94d243b3d7916879f1735d4a6f231e915765c1c4"
Commit 32d06c66 authored by wangshaojie6

Support bf16 split-k kernel for convnd bwd weight

parent bccc6d8b
@@ -329,24 +329,7 @@ int main(int argc, char* argv[])
     DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size);
     wei_work_space_device_buf.SetZero();

-    argument = conv->MakeArgumentPointer(
-        static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
-        static_cast<AccDataType*>(wei_work_space_device_buf.GetDeviceBuffer()),
-        static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
-        params.N_,
-        params.K_,
-        params.C_,
-        params.input_spatial_lengths_,
-        params.filter_spatial_lengths_,
-        output_spatial_lengths,
-        params.conv_filter_strides_,
-        params.conv_filter_dilations_,
-        params.input_left_pads_,
-        params.input_right_pads_,
-        InElementOp{},
-        WeiElementOp{},
-        OutElementOp{},
-        split_k);
+    conv->SetWorkSpacePointer(argument.get(), wei_work_space_device_buf.GetDeviceBuffer());

     if(!conv->IsSupportedArgument(argument.get()))
     {
@@ -358,6 +341,7 @@ int main(int argc, char* argv[])
     conv_ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});

+#if 0
     // do type convert
     auto type_convert = DeviceUnaryElementwiseTypeConvertInstance{};
     auto type_convert_invoker = type_convert.MakeInvokerPointer();
@@ -381,7 +365,7 @@ int main(int argc, char* argv[])
     type_convert_ave_time =
         type_convert_invoker->Run(type_convert_argument.get(), StreamConfig{nullptr, time_kernel});
     // type_convert_invoker->Run(type_convert_argument.get(), StreamConfig{nullptr, time_kernel});
+#endif

     // host code to check if conv give me a right result
     // Tensor<AccDataType> wei_k_c_y_x_device_result_fp32(
     //     ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
......
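For orientation, the example above no longer passes the fp32 workspace buffer into MakeArgumentPointer as the weight output; it now hands the scratch buffer to the operator through the new SetWorkSpacePointer hook. The lines below are a minimal host-side sketch of that pattern, reusing the names from the hunk (conv, argument, invoker, DeviceMem, split_k); the elided MakeArgumentPointer parameters and the exact workspace-size query are assumptions, not a verbatim copy of the example.

// Sketch: build the argument first, then attach the split-k fp32 workspace.
auto argument = conv->MakeArgumentPointer(/* in (bf16), wei (bf16), out (bf16) pointers, conv params..., split_k */);

// Ask the operator how much fp32 scratch the bf16 split-k path needs (assumption:
// sized from the argument via the GetWorkSpaceSize override added in this commit).
const std::size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get());
DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size);
wei_work_space_device_buf.SetZero();
conv->SetWorkSpacePointer(argument.get(), wei_work_space_device_buf.GetDeviceBuffer());

if(conv->IsSupportedArgument(argument.get()))
{
    conv_ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
}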
@@ -42,6 +42,8 @@ struct BaseOperator
     virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }

+    virtual void SetWorkSpacePointer(BaseArgument*, void*) const {}
+
     virtual ~BaseOperator() {}
 };
......
@@ -11,6 +11,7 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "gridwise_gemm_xdlops_bwd_weight.hpp"
+#include "gridwise_unary_elementwise_1d.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -628,6 +629,54 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
             1);
     }

+    // type convert descs
+    template <typename Desc_M0>
+    static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
+    {
+        const auto m0 = desc_m0.GetLength(I0);
+        const index_t loop_step = gridSize * blockSize * 4;
+        const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
+        const auto desc_m0_pad =
+            transform_tensor_descriptor(desc_m0,
+                                        make_tuple(make_right_pad_transform(m0, pad)),
+                                        make_tuple(Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}));
+        return desc_m0_pad;
+    }
+
+    template <index_t Dim>
+    static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
+                                  const std::vector<index_t>& stride,
+                                  index_t gridSize,
+                                  index_t blockSize)
+    {
+        auto tupleOfShape  = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
+        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
+
+        // nd desc - [s0, s1, s2, ...]
+        const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
+
+        // merge nd to 1d desc - [s0 * s1 * ...]
+        if constexpr(Dim > 1)
+        {
+            const auto desc_m0 = transform_tensor_descriptor(
+                desc,
+                make_tuple(make_merge_transform(tupleOfShape)),
+                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
+                make_tuple(Sequence<0>{}));
+
+            return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
+        }
+        else
+            return PadDescriptor_M0_1d(desc, gridSize, blockSize);
+    }
+
+    using TypeConvertFunctor =
+        ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;
+    using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1));
+    using GridwiseUEltwise =
+        GridwiseUnaryElementwise_1D<AccDataType, InDataType, GridDesc_M0, TypeConvertFunctor, 4>;
+
     using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());

     using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
@@ -851,6 +900,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
             b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
             c_grid_desc_m_n_ = descs[I2];

+            // init work space
+            p_c_workspace_grid_ = nullptr;
+
             block_2_ctile_map_ =
                 GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
@@ -887,6 +939,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
         std::vector<index_t> input_left_pads_;
         std::vector<index_t> input_right_pads_;
         index_t k_batch_;
+
+        // external work space
+        void* p_c_workspace_grid_;
     };

     // Invoker
@@ -959,6 +1014,64 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                                            arg.block_2_ctile_map_);
             };

+            // run kernel for bf16 with splitk
+            const auto Run_bf16_splitk = [&](const auto& kernel) {
+                hipGetErrorString(hipMemset(
+                    arg.p_c_workspace_grid_,
+                    0,
+                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
+                        sizeof(AccDataType)));
+
+                ave_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(grid_size),
+                                           dim3(BlockSize),
+                                           0,
+                                           arg.p_a_grid_,
+                                           arg.p_b_grid_,
+                                           static_cast<AccDataType*>(arg.p_c_workspace_grid_),
+                                           arg.a_grid_desc_kbatch_k0_m_k1_,
+                                           arg.b_grid_desc_kbatch_k0_n_k1_,
+                                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                           arg.a_element_op_,
+                                           arg.b_element_op_,
+                                           arg.c_element_op_,
+                                           arg.block_2_ctile_map_);
+            };
+
+            // kernel for type conversion
+            std::vector<std::size_t> filter_dims{static_cast<std::size_t>(arg.Conv_K_),
+                                                 static_cast<std::size_t>(arg.Conv_C_)};
+            filter_dims.insert(std::end(filter_dims),
+                               std::begin(arg.filter_spatial_lengths_),
+                               std::end(arg.filter_spatial_lengths_));
+
+            int tensor_size =
+                std::accumulate(filter_dims.begin(), filter_dims.end(), 1, std::multiplies<int>{});
+
+            GridDesc_M0 a_grid_desc_m0_ = MakeDescriptor_M0<1>({tensor_size}, {1}, 240, 256);
+            GridDesc_M0 b_grid_desc_m0_ = MakeDescriptor_M0<1>({tensor_size}, {1}, 240, 256);
+
+            // run kernel for type conversion
+            void* p_c_grid_tmp_ = static_cast<void*>(arg.p_c_grid_);
+            InDataType* p_c_grid_tmp_bf16_ = static_cast<InDataType*>(p_c_grid_tmp_);
+            const auto Run_type_convert = [&](const auto& kernel) {
+                float elapsed_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(240),
+                                           dim3(256),
+                                           0,
+                                           static_cast<AccDataType*>(arg.p_c_workspace_grid_),
+                                           p_c_grid_tmp_bf16_,
+                                           a_grid_desc_m0_,
+                                           b_grid_desc_m0_,
+                                           TypeConvertFunctor{});
+                return elapsed_time;
+            };
+
             if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
             {
                 if(has_main_k0_block_loop)
@@ -983,7 +1096,14 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                 }
                 else
                 {
-                    const auto kernel = kernel_gemm_xdlops_bwd_weight<
+                    const auto kernel_type_convert =
+                        kernel_unary_elementwise_1d<GridwiseUEltwise,
+                                                    AccDataType,
+                                                    InDataType,
+                                                    GridDesc_M0,
+                                                    TypeConvertFunctor>;
+
+                    const auto kernel_conv = kernel_gemm_xdlops_bwd_weight<
                         GridwiseGemmAtomicAddFloatBf16Splitk,
                         ADataType, // TODO: distiguish A/B datatype
                         AccDataType,
@@ -997,7 +1117,8 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                         remove_reference_t<DeviceOp::Block2CTileMap>,
                         true>;

-                    Run(kernel);
+                    Run_bf16_splitk(kernel_conv);
+                    ave_time += Run_type_convert(kernel_type_convert);
                 }
             }
             else
@@ -1036,7 +1157,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                         remove_reference_t<DeviceOp::Block2CTileMap>,
                         false>;

-                    Run(kernel);
+                    Run_bf16_splitk(kernel);
                 }
             }
         }
@@ -1319,6 +1440,11 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
     {
         return GetWorkSpaceSize<NumDimSpatial>(*dynamic_cast<const Argument*>(p_arg));
     }

+    void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
+    {
+        dynamic_cast<Argument*>(p_arg)->p_c_workspace_grid_ = workspace_ptr;
+    }
 };

 } // namespace device
......
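To summarize the device-side flow added above: Run_bf16_splitk zeroes an fp32 workspace and accumulates the weight gradient into it across the split-k batches (atomic-add GEMM), and Run_type_convert then makes a second, 1-D elementwise pass that converts the fp32 result into the user's bf16 weight buffer. The kernel below is an illustrative standalone HIP sketch of the shape of that second pass only, assuming a 240 x 256 launch with 4 elements per thread per step as in the launch parameters above and a simple truncating float-to-bf16 conversion; it is not the actual GridwiseUnaryElementwise_1D / UnaryTypeConvert implementation, which goes through CK's tensor descriptors.

#include <hip/hip_runtime.h>

// Illustrative fp32 -> bf16 conversion pass (assumption: a plain grid-stride loop,
// 4 elements per thread per step, bf16 stored as the upper 16 bits of fp32).
__global__ void convert_fp32_to_bf16(const float* __restrict__ src,
                                     unsigned short* __restrict__ dst,
                                     int n)
{
    const int stride = gridDim.x * blockDim.x * 4; // matches loop_step in PadDescriptor_M0_1d
    for(int base = (blockIdx.x * blockDim.x + threadIdx.x) * 4; base < n; base += stride)
    {
        for(int i = 0; i < 4; ++i)
        {
            const int idx = base + i;
            if(idx < n)
            {
                // truncation (no rounding) for brevity; the real functor may differ
                dst[idx] = static_cast<unsigned short>(__float_as_uint(src[idx]) >> 16);
            }
        }
    }
}

// Usage mirroring the dims used in Run_type_convert above (dim3(240), dim3(256)):
// convert_fp32_to_bf16<<<240, 256>>>(workspace_fp32, weight_bf16, tensor_size);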
@@ -111,15 +111,6 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
     }
     else
     {
-        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
-               __func__,
-               grid_dim.x,
-               grid_dim.y,
-               grid_dim.z,
-               block_dim.x,
-               block_dim.y,
-               block_dim.z);
-
         kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

         return 0;
......