"tests/vscode:/vscode.git/clone" did not exist on "4f6399bedd28a9af0e18b1ebf84c38eb07ffa769"
Unverified Commit 561ec12f authored by Shaojie WANG's avatar Shaojie WANG Committed by GitHub
Browse files

example for convnd bwd weight bf16 splitk (#265)

* add GetWorkSpaceSize to base arg and make an example on convnd_bwd_weight

* add bwd weight for bf16: init

* remove redundant compute

* use datatype and split k to check whether a workspace is used

* remove unused computation for work space size

* add some code for bfp16

* add device/grid unary op

* add unary type convert to bwd-weight example

* support bf16 splitk kernel for convnd bwd weight

* 1. remove comments. 2. add checkvalidity. 3. add gridsize computation

* add workspace size check

* fix format

* change function name
parent fb9b6b1e
add_example_executable(example_convnd_bwd_weight_xdl convnd_bwd_weight_xdl.cpp)
add_example_executable(example_convnd_bwd_weight_xdl_bf16_splitk convnd_bwd_weight_xdl_bf16_splitk.cpp)
target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util)
target_link_libraries(example_convnd_bwd_weight_xdl_bf16_splitk PRIVATE conv_util)
\ No newline at end of file
......@@ -297,43 +297,7 @@ int main(int argc, char* argv[])
split_k);
// alloc work space
size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get());
float ave_time = 0.f;
if(std::is_same<InDataType, ck::bhalf_t>::value && split_k > 1)
{
DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size);
wei_work_space_device_buf.SetZero();
argument = conv->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<AccDataType*>(wei_work_space_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
params.N_,
params.K_,
params.C_,
params.input_spatial_lengths_,
params.filter_spatial_lengths_,
output_spatial_lengths,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
InElementOp{},
WeiElementOp{},
OutElementOp{},
split_k);
if(!conv->IsSupportedArgument(argument.get()))
{
std::cout << "wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<< std::endl;
return 1;
}
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
}
else
{
if(!conv->IsSupportedArgument(argument.get()))
{
std::cout << "wrong! device_conv with the specified compilation parameters does "
......@@ -342,7 +306,6 @@ int main(int argc, char* argv[])
return 1;
}
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
}
std::size_t flop = ck::utils::conv::get_flops(
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
......
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "conv_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "tensor_layout.hpp"
#include "element_wise_operation.hpp"
#include "device_unary_elementwise.hpp"
#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "reference_conv_backward_weight.hpp"
using InDataType = ck::bhalf_t;
using WeiDataType = ck::bhalf_t;
using OutDataType = ck::bhalf_t;
using AccDataType = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryTypeConvert = ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;
using DeviceUnaryElementwiseTypeConvertInstance = ck::tensor_operation::device::
DeviceUnaryElementwise<AccDataType, WeiDataType, UnaryTypeConvert, 1, 4>;
static constexpr auto ConvBwdWeightDefault =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
using DeviceConvBwdWeightBasePtr =
ck::tensor_operation::device::DeviceConvBwdWeightPtr<InElementOp, WeiElementOp, OutElementOp>;
// clang-format off
template <ck::index_t NumDimSpatial>
using DeviceConvndBwdWeightInstance_bf16_splitk = ck::tensor_operation::device::
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
InDataType, // InDataType
AccDataType, // WeiDataType
OutDataType, // OutDataType
AccDataType, // AccDataType
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
NumDimSpatial, // NumDimSpatial
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
template <ck::index_t NumDimSpatial>
using ReferenceConvBwdWeightInstance =
ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
NumDimSpatial>;
template <typename HostTensorB, typename HostTensorA, typename Functor>
void host_elementwise(HostTensorB& B,
const HostTensorA& A,
const std::vector<std::size_t>& shape,
Functor functor)
{
size_t tensor_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>{});
std::cout << __LINE__ << ":" << tensor_size << ", " << A.mData[0] << std::endl;
for(std::size_t n = 0; n < tensor_size; ++n)
{
B.mData[n] = functor(A.mData[n]);
}
}
void print_use_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
<< "arg3: time kernel (0=n0, 1=yes)\n"
<< "arg4: is show log (0=no, 1=yes)\n"
<< "arg5: split-k : in this example split-k must be larger than 1\n"
<< "arg6: N spatial dimensions (default 2)\n"
<< "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <strides>, (ie Sy, Sx for 2D)\n"
<< " <dilations>, (ie Dy, Dx for 2D)\n"
<< " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
<< std::endl;
}
ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
{
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
ck::utils::conv::ConvParams params;
int arg_idx = 7;
params.num_dim_spatial_ = num_dim_spatial;
params.N_ = std::stoi(argv[arg_idx++]);
params.K_ = std::stoi(argv[arg_idx++]);
params.C_ = std::stoi(argv[arg_idx++]);
params.filter_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
}
params.input_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
}
params.conv_filter_strides_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
}
params.conv_filter_dilations_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
}
params.input_left_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
}
params.input_right_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
}
return params;
}
DeviceConvBwdWeightBasePtr get_conv_instance(int num_dim_spatial)
{
switch(num_dim_spatial)
{
case 3: {
return std::make_unique<DeviceConvndBwdWeightInstance_bf16_splitk<3>>();
}
case 2: {
return std::make_unique<DeviceConvndBwdWeightInstance_bf16_splitk<2>>();
}
case 1: {
return std::make_unique<DeviceConvndBwdWeightInstance_bf16_splitk<1>>();
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int num_dim_spatial = 2;
int do_log = 0;
int split_k = 2;
ck::utils::conv::ConvParams params;
params.C_ = 128;
if(argc == 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
do_log = std::stoi(argv[4]);
split_k = std::stoi(argv[5]);
}
else if(argc > 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
do_log = std::stoi(argv[4]);
split_k = std::stoi(argv[5]);
num_dim_spatial = std::stoi(argv[6]);
// check args number
int conv_args = 3 + num_dim_spatial * 6;
int cmdline_nargs = conv_args + 7;
if(cmdline_nargs != argc)
{
print_use_msg();
exit(1);
}
params = parse_conv_params(num_dim_spatial, argv);
}
else if(argc != 1)
{
print_use_msg();
exit(1);
}
if(split_k <= 1)
{
print_use_msg();
exit(1);
}
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.C_)};
input_dims.insert(std::end(input_dims),
std::begin(params.input_spatial_lengths_),
std::end(params.input_spatial_lengths_));
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
static_cast<std::size_t>(params.C_)};
filter_dims.insert(std::end(filter_dims),
std::begin(params.filter_spatial_lengths_),
std::end(params.filter_spatial_lengths_));
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.K_)};
output_dims.insert(std::end(output_dims),
std::begin(output_spatial_lengths),
std::end(output_spatial_lengths));
Tensor<InDataType> in_n_c_hi_wi(
ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
Tensor<WeiDataType> wei_k_c_y_x_host_result(
ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
Tensor<WeiDataType> wei_k_c_y_x_device_result(
ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
Tensor<OutDataType> out_n_k_ho_wo(
ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_device_result.mDesc << std::endl;
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl;
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
break;
default:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) *
wei_k_c_y_x_device_result.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
// reset input to zero
wei_device_buf.SetZero();
// do GEMM
auto conv = get_conv_instance(num_dim_spatial);
auto invoker = conv->MakeInvokerPointer();
auto argument =
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
params.N_,
params.K_,
params.C_,
params.input_spatial_lengths_,
params.filter_spatial_lengths_,
output_spatial_lengths,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
InElementOp{},
WeiElementOp{},
OutElementOp{},
split_k);
// alloc work space
size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get());
if(bwd_weight_workspace_size <= 0)
{
print_use_msg();
exit(1);
}
float conv_ave_time = 0.f;
DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size);
wei_work_space_device_buf.SetZero();
conv->SetWorkSpacePointer(argument.get(), wei_work_space_device_buf.GetDeviceBuffer());
if(!conv->IsSupportedArgument(argument.get()))
{
std::cout << "wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<< std::endl;
return 1;
}
conv_ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = ck::utils::conv::get_flops(
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
params.N_,
params.C_,
params.K_,
params.input_spatial_lengths_,
params.filter_spatial_lengths_,
output_spatial_lengths);
float tflops = static_cast<float>(flop) / 1.E9 / conv_ave_time;
float gb_per_sec = num_btype / 1.E6 / conv_ave_time;
std::cout << "Perf: conv: " << conv_ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s" << std::endl;
if(do_verification)
{
auto verify_f = [&](const auto& ref_conv) {
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
wei_k_c_y_x_host_result,
out_n_k_ho_wo,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
InElementOp{},
WeiElementOp{},
OutElementOp{});
ref_invoker.Run(ref_argument);
wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data());
if(do_log)
{
LogRangeAsType<float>(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",")
<< std::endl;
}
return ck::utils::check_err(wei_k_c_y_x_device_result.mData,
wei_k_c_y_x_host_result.mData)
? 0
: 1;
};
switch(num_dim_spatial)
{
case 3: {
auto ref_conv = ReferenceConvBwdWeightInstance<3>();
verify_f(ref_conv);
break;
}
case 2: {
auto ref_conv = ReferenceConvBwdWeightInstance<2>();
verify_f(ref_conv);
break;
}
case 1: {
auto ref_conv = ReferenceConvBwdWeightInstance<1>();
verify_f(ref_conv);
break;
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
return 0;
}
......@@ -11,6 +11,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_bwd_weight.hpp"
#include "gridwise_unary_elementwise_1d.hpp"
namespace ck {
namespace tensor_operation {
......@@ -628,6 +629,54 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
1);
}
// type convert descs
template <typename Desc_M0>
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
{
const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * blockSize * 4;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
const auto desc_m0_pad =
transform_tensor_descriptor(desc_m0,
make_tuple(make_right_pad_transform(m0, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m0_pad;
}
template <index_t Dim>
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
const std::vector<index_t>& stride,
index_t gridSize,
index_t blockSize)
{
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(Dim > 1)
{
const auto desc_m0 = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
}
else
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
}
using TypeConvertFunctor =
ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;
using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1));
using GridwiseUEltwise =
GridwiseUnaryElementwise_1D<AccDataType, InDataType, GridDesc_M0, TypeConvertFunctor, 4>;
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
......@@ -733,6 +782,55 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
true,
true>;
using GridwiseGemmAtomicAddFloatBf16Splitk = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
AccDataType,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXdl,
NPerXdl,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
ABlockLdsM1PerBlock,
ABlockLdsM0PerBlock,
ABlockLdsM1Padding,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
BBlockLdsN1PerBlock,
BBlockLdsN0PerBlock,
BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
......@@ -802,6 +900,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
// init work space
p_c_workspace_grid_ = nullptr;
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
......@@ -838,6 +939,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
index_t k_batch_;
// external work space
void* p_c_workspace_grid_;
};
// Invoker
......@@ -910,9 +1014,77 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.block_2_ctile_map_);
};
// run kernel for bf16 with splitk
const auto run_bf16_splitk = [&](const auto& kernel) {
hipGetErrorString(hipMemset(
arg.p_c_workspace_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(AccDataType)));
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
static_cast<AccDataType*>(arg.p_c_workspace_grid_),
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};
// kernel for type conversion
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(arg.Conv_K_),
static_cast<std::size_t>(arg.Conv_C_)};
filter_dims.insert(std::end(filter_dims),
std::begin(arg.filter_spatial_lengths_),
std::end(arg.filter_spatial_lengths_));
int tensor_size =
std::accumulate(filter_dims.begin(), filter_dims.end(), 1, std::multiplies<int>{});
const index_t type_convert_grid_size = GridwiseUEltwise::CalculateGridSize(tensor_size);
GridDesc_M0 a_grid_desc_m0_ =
MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256);
GridDesc_M0 b_grid_desc_m0_ =
MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256);
if(!GridwiseUEltwise::CheckValidity(a_grid_desc_m0_, b_grid_desc_m0_))
{
throw std::runtime_error("wrong! GridwiseUnaryElementwise_1D has invalid setting");
}
// run kernel for type conversion
void* p_c_grid_tmp_ = static_cast<void*>(arg.p_c_grid_);
InDataType* p_c_grid_tmp_bf16_ = static_cast<InDataType*>(p_c_grid_tmp_);
const auto Run_type_convert = [&](const auto& kernel) {
float elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(type_convert_grid_size),
dim3(256),
0,
static_cast<AccDataType*>(arg.p_c_workspace_grid_),
p_c_grid_tmp_bf16_,
a_grid_desc_m0_,
b_grid_desc_m0_,
TypeConvertFunctor{});
return elapsed_time;
};
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
{
if(has_main_k0_block_loop)
{
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
......@@ -920,7 +1092,8 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
......@@ -930,6 +1103,35 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
Run(kernel);
}
else
{
const auto kernel_type_convert =
kernel_unary_elementwise_1d<GridwiseUEltwise,
AccDataType,
InDataType,
GridDesc_M0,
TypeConvertFunctor>;
const auto kernel_conv = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAddFloatBf16Splitk,
ADataType, // TODO: distiguish A/B datatype
AccDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
true>;
run_bf16_splitk(kernel_conv);
ave_time += Run_type_convert(kernel_type_convert);
}
}
else
{
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
......@@ -937,7 +1139,8 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
......@@ -946,6 +1149,25 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAddFloatBf16Splitk,
ADataType, // TODO: distiguish A/B datatype
AccDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
false>;
run_bf16_splitk(kernel);
}
}
}
else
{
......@@ -1226,6 +1448,11 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
{
return GetWorkSpaceSize<NumDimSpatial>(*dynamic_cast<const Argument*>(p_arg));
}
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
{
dynamic_cast<Argument*>(p_arg)->p_c_workspace_grid_ = workspace_ptr;
}
};
} // namespace device
......
#pragma once
#include <iostream>
#include <vector>
#include "device.hpp"
#include "device_base.hpp"
#include "gridwise_unary_elementwise_1d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ADataType,
typename BDataType,
typename ElementwiseFunctor,
index_t Dim,
index_t ScalarPerVector>
struct DeviceUnaryElementwise : public BaseOperator
{
static constexpr auto I0 = Number<0>{};
template <typename Desc_M0>
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
{
const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * blockSize * ScalarPerVector;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
const auto desc_m0_pad =
transform_tensor_descriptor(desc_m0,
make_tuple(make_right_pad_transform(m0, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m0_pad;
}
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
const std::vector<index_t>& stride,
index_t gridSize,
index_t blockSize)
{
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(Dim > 1)
{
const auto desc_m0 = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
}
else
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
}
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
using GridwiseUEltwise = GridwiseUnaryElementwise_1D<ADataType,
BDataType,
GridDesc_M0,
ElementwiseFunctor,
ScalarPerVector>;
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a,
BDataType* p_b,
const std::vector<index_t>& shape,
const std::vector<index_t>& stride_a,
const std::vector<index_t>& stride_b,
ElementwiseFunctor functor)
: p_a_(p_a),
p_b_(p_b),
shape_(shape),
functor_(functor),
blockSize_(256) // FIXME - Calculate the grid size by number of CU in the future
{
index_t tensor_size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>{});
gridSize_ = GridwiseUEltwise::CalculateGridSize(tensor_size);
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
}
const ADataType* p_a_;
BDataType* p_b_;
std::vector<int> shape_;
GridDesc_M0 a_grid_desc_m0_;
GridDesc_M0 b_grid_desc_m0_;
ElementwiseFunctor functor_;
index_t blockSize_;
index_t gridSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel = kernel_unary_elementwise_1d<GridwiseUEltwise,
ADataType,
BDataType,
GridDesc_M0,
ElementwiseFunctor>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.p_a_,
arg.p_b_,
arg.a_grid_desc_m0_,
arg.b_grid_desc_m0_,
arg.functor_);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->shape_.back() % ScalarPerVector != 0)
return false;
return true;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
void* p_b,
std::vector<index_t> shape,
std::vector<index_t> stride_a,
std::vector<index_t> stride_b,
ElementwiseFunctor functor)
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<BDataType*>(p_b),
shape,
stride_a,
stride_b,
functor);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBinaryElementwise"
<< "<"
<< "ScalarPerVector = " << ScalarPerVector
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -346,6 +346,27 @@ struct UnarySqrt<double, double>
};
};
template <typename Y, typename X>
struct UnaryTypeConvert;
template <>
struct UnaryTypeConvert<float, ck::bhalf_t>
{
__host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
{
y = ck::type_convert<float, ck::bhalf_t>(x);
};
};
template <>
struct UnaryTypeConvert<ck::bhalf_t, float>
{
__host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
{
y = ck::type_convert<ck::bhalf_t, float>(x);
};
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -791,8 +791,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
void* p_shared = static_cast<void*>(p_shared_block);
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared_block),
static_cast<FloatC*>(p_shared),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
static_assert(M1 == MWave, "");
......
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename GridwiseUEltwise,
typename ADataType,
typename BDataType,
typename GridDesc_M0,
typename ElementwiseFunctor>
__global__ void kernel_unary_elementwise_1d(const ADataType* __restrict__ p_a_global,
BDataType* __restrict__ p_b_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const ElementwiseFunctor functor)
{
GridwiseUEltwise::Run(p_a_global, p_b_global, a_grid_desc_m0, b_grid_desc_m0, functor);
}
template <typename ADataType,
typename BDataType,
typename GridDesc_M0,
typename ElementwiseFunctor,
index_t ScalarPerVector>
struct GridwiseUnaryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_desc_m0 =
make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
static __device__ auto CalculateElementwiseIndex()
{
const index_t global_thread_id = get_thread_global_1d_id();
return make_multi_index(global_thread_id * ScalarPerVector);
}
__host__ __device__ static constexpr bool CheckValidity(const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0)
{
return a_grid_desc_m0.GetLength(I0) == b_grid_desc_m0.GetLength(I0);
}
__host__ __device__ static constexpr index_t CalculateGridSize(const index_t tensor_size)
{
const index_t grid_size = math::integer_divide_ceil(tensor_size, 256 * ScalarPerVector);
return grid_size;
}
__device__ static void Run(const ADataType* __restrict__ p_a_global,
BDataType* __restrict__ p_b_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const ElementwiseFunctor functor)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m0.GetElementSpaceSize());
auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m0.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;
const auto thread_store_global_offset = CalculateElementwiseIndex();
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ADataType,
GridDesc_M0,
decltype(thread_desc_m0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
ScalarPerVector,
1, // SrcScalarStrideInVector
false>{a_grid_desc_m0, thread_store_global_offset};
auto b_global_write =
ThreadwiseTensorSliceTransfer_v1r3<BDataType,
BDataType,
decltype(thread_desc_m0),
GridDesc_M0,
PassThrough,
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
ScalarPerVector,
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
b_grid_desc_m0, thread_store_global_offset, PassThrough{}};
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto m0 = b_grid_desc_m0.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector;
const auto loop_step_index = make_multi_index(loop_step);
index_t num_iter = m0 / (loop_step);
do
{
// read and process ScalarPerVector elements
a_global_load.Run(
a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);
static_for<0, ScalarPerVector, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
functor(b_thread_buf(Number<offset>{}), a_thread_buf(Number<offset>{}));
});
b_global_write.Run(thread_desc_m0,
make_tuple(I0), // SrcSliceOriginIdx
b_thread_buf,
b_grid_desc_m0,
b_global_buf);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
b_global_write.MoveDstSliceWindow(b_grid_desc_m0, loop_step_index);
} while(--num_iter);
}
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment