Commit 475b6d1f authored by danyao12, committed by wunhuang

refactor grouped bwd example and fix some bugs

parent d0f055e9
@@ -25,7 +25,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
-#define RANGE_HDKO 2 // 0~2
+#define RANGE_HDKO 1 // 0~2
 #include <iostream>
 #include <numeric>
@@ -523,7 +523,7 @@ int run(int argc, char* argv[])
         G0            = std::stoi(argv[8]);
         G1            = std::stoi(argv[9]);
         alpha         = std::stof(argv[10]);
         p_drop        = std::stof(argv[11]);
         input_permute = std::stoi(argv[12]);
@@ -540,9 +540,9 @@ int run(int argc, char* argv[])
        exit(0);
    }
    float p_dropout               = 1 - p_drop;
    uint16_t p_dropout_in_16bits  = uint16_t(std::floor(p_dropout * 65535.0));
    float rp_dropout              = 1.0 / p_dropout;
    std::cout << "do_verification: " << do_verification << std::endl;
    std::cout << "init_method: " << init_method << std::endl;
@@ -678,7 +678,6 @@ int run(int argc, char* argv[])
        // = 0
    }
-   // calculate y & log-sum-exp beforehand
    Tensor<DataType> q_g_m_k({BatchCount, M, K});
    Tensor<DataType> k_g_n_k({BatchCount, N, K});
    Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
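
These tensors are the operands of the backward pass, flattened to a single group dimension. A rough shape sketch (dimension meanings inferred from the `_g_m_k`-style suffixes, sizes purely hypothetical):

```cpp
#include <cstddef>
#include <vector>

// G = flattened batch * head count, M = query length, N = key length,
// K = head dimension of Q/K. S = Q * K^T is G x M x N, so the dropout
// mask Z shares that shape and is stored as 16-bit RNG values.
const int G = 4, M = 256, N = 256, K = 64; // hypothetical sizes

std::vector<float> q(std::size_t(G) * M * K);          // q_g_m_k
std::vector<float> k_mat(std::size_t(G) * N * K);      // k_g_n_k
std::vector<unsigned short> z(std::size_t(G) * M * N); // z_g_m_n
```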
...
@@ -98,7 +98,7 @@ __global__ void
        unsigned short* z_matrix_ptr =
            (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                    : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
        GridwiseGemm::template Run<HasMainKBlockLoop>(
            arg_ptr[group_id].p_a_grid_ + a_batch_offset,
            arg_ptr[group_id].p_b_grid_ + b_batch_offset,
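
The ternary guarding `p_z_grid_` matters: Z is optional (only produced when dropout is active), and adding a nonzero offset to a null pointer is undefined behavior in C++, so the batch offset must be skipped entirely when the pointer is null. The pattern in isolation:

```cpp
#include <cstddef>

// Offset an optional grid pointer per batch; when the grid was never
// allocated (base == nullptr), return nullptr instead of computing
// nullptr + offset, which is undefined behavior.
template <typename T>
T* offset_or_null(T* base, std::ptrdiff_t batch_offset)
{
    return base == nullptr ? nullptr : base + batch_offset;
}
```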
@@ -211,7 +211,7 @@ template <index_t NumDimG,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          MaskingSpecialization MaskingSpec,
          LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
+struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
    : public BaseOperator // TODO inherit atten bwd op once API stabilizes
{
    static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -223,7 +223,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
    // TODO: implement bias combination
    static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
-   using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1;
+   using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
    struct ProblemDesc
    {
        std::vector<index_t> a_gs_ms_ks_lengths;
@@ -448,7 +448,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
        }
    }
    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
@@ -711,7 +710,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
        }
        grid_size_ = 0;
-       for(std::size_t i = 0; i < group_count_; i++)
+       for(index_t i = 0; i < group_count_; i++)
        {
            const auto p_a_grid = static_cast<const DataType*>(p_As[i]);
            const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]);
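
This hunk (and the matching ones below) switches the group loop counter from `std::size_t` to `index_t`, so the counter matches the type of `group_count_` instead of mixing signed and unsigned operands. A sketch of the difference, assuming `index_t` is a signed 32-bit integer as is conventional in composable_kernel:

```cpp
#include <cstdint>

using index_t = int32_t; // assumed definition

void walk_groups(index_t group_count)
{
    // With a std::size_t counter, `i < group_count` compares unsigned
    // against signed and draws -Wsign-compare; a matching index_t
    // counter keeps the comparison homogeneous.
    for(index_t i = 0; i < group_count; i++)
    {
        // ... per-group argument setup ...
    }
}
```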
@@ -895,7 +894,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
        // for(std::size_t i = 0; i < arg.group_count_; i++)
        // {
        //     const auto K =
-       //         arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
+       //         arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
+       //         arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
        //     const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
        //     all_has_main_k_block_loop &= y;
        //     some_has_main_k_block_loop |= y;
@@ -976,7 +976,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
            return false;
        }
-       for(std::size_t i = 0; i < arg.group_count_; i++)
+       for(index_t i = 0; i < arg.group_count_; i++)
        {
            // TODO: Check if tensor specialization & strides mismatch
            const auto& kernel_arg = arg.group_kernel_args_[i];
@@ -986,7 +986,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
            const index_t c_m       = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
            const index_t c_gemm1n  = kernel_arg.y_grid_desc_m_o_.GetLength(I1);
            const index_t a_m       = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
-           const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
+           const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
+                                     kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
            if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
            {
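
The reformatted expression recomposes an extent that the blocked descriptor stores split across its first and last axes, then checks it against the extent the output descriptor reports directly. The invariant in isolation (plain structs standing in for the CK grid descriptors, which is an assumption for illustration):

```cpp
// A blocked (K0, X, K1) descriptor carries an extent split as K0 * K1;
// IsSupportedArgument recomposes it and compares it against the output
// descriptor's extent.
struct BlockedDesc
{
    int k0, x, k1;
};

bool gemm1n_consistent(int y_gemm1n, const BlockedDesc& b1)
{
    const int b1_gemm1n = b1.k0 * b1.k1; // GetLength(I0) * GetLength(I2)
    return y_gemm1n == b1_gemm1n;
}
```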
@@ -1160,7 +1161,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
        auto str = std::stringstream();
        // clang-format off
-       str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1"
+       str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
...
@@ -98,7 +98,7 @@ __global__ void
        unsigned short* z_matrix_ptr =
            (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                    : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
        GridwiseGemm::template Run<HasMainKBlockLoop>(
            arg_ptr[group_id].p_a_grid_ + a_batch_offset,
            arg_ptr[group_id].p_b_grid_ + b_batch_offset,
@@ -703,7 +703,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
        }
        grid_size_ = 0;
-       for(std::size_t i = 0; i < group_count_; i++)
+       for(index_t i = 0; i < group_count_; i++)
        {
            const auto p_a_grid = static_cast<const DataType*>(p_As[i]);
            const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]);
@@ -884,10 +884,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
        bool all_has_main_k_block_loop  = true;
        bool some_has_main_k_block_loop = false;
-       for(std::size_t i = 0; i < arg.group_count_; i++)
+       for(index_t i = 0; i < arg.group_count_; i++)
        {
-           const auto K =
-               arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
+           const auto K = arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
+                          arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
            const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
            all_has_main_k_block_loop &= y;
            some_has_main_k_block_loop |= y;
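
The all/some pair lets the launcher pick a single `HasMainKBlockLoop` kernel instantiation when every group agrees, and reject the launch when they disagree. A sketch of the reduction, with a stand-in predicate (the real one is `GridwiseGemm::CalculateHasMainKBlockLoop`, whose exact formula is not shown here):

```cpp
#include <vector>

constexpr int KPerBlock = 32; // hypothetical tile size

// Stand-in predicate: the main K loop exists when K spans more than one
// block tile. The actual CK predicate may differ.
bool has_main_k_block_loop(int k) { return k > KPerBlock; }

void classify_groups(const std::vector<int>& group_k, bool& all, bool& some)
{
    all  = true;  // stays true only if every group runs the main loop
    some = false; // becomes true if any group runs it
    for(int k : group_k)
    {
        const bool y = has_main_k_block_loop(k);
        all  &= y;
        some |= y;
    }
    // all != some means the groups disagree, and a single template
    // instantiation cannot serve them all.
}
```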
@@ -968,7 +968,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
            return false;
        }
-       for(std::size_t i = 0; i < arg.group_count_; i++)
+       for(index_t i = 0; i < arg.group_count_; i++)
        {
            // TODO: Check if tensor specialization & strides mismatch
            const auto& kernel_arg = arg.group_kernel_args_[i];
...