Commit e327363f authored by fsx950223

fix bugs

parent 0fe4fb38
@@ -306,17 +306,18 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     ref_softmax_invoker.Run(ref_softmax_argument);
+    // P_dropped
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment =
         ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
-    // Y = P * V
+    // Y = P_dropout * V
     auto ref_gemm1          = ReferenceGemm1Instance{};
     auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
     auto ref_gemm1_argument = ref_gemm1.MakeArgument(
-        p_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
+        p_drop_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
     ref_gemm1_invoker.Run(ref_gemm1_argument);
 }
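For context, this hunk repoints the reference Gemm1 at the dropout-scaled probabilities: Y must be computed from P_dropped, not from the raw softmax output P. Below is a minimal standalone sketch of that tail of the forward reference path, a hypothetical helper rather than the CK reference kernels; it assumes the common convention that an element is kept when its 16-bit random draw falls below the threshold and is then rescaled by rp_dropout (the exact keep/drop convention lives in ReferenceDropout):

    #include <cstdint>
    #include <vector>

    // p, p_drop, v, y are row-major; z holds per-element random draws in [0, 65535].
    void attention_fwd_tail(const std::vector<float>& p,    // M x N, softmax(S)
                            const std::vector<uint16_t>& z, // M x N, random draws
                            const std::vector<float>& v,    // N x O
                            std::vector<float>& p_drop,     // M x N (out)
                            std::vector<float>& y,          // M x O (out)
                            int M, int N, int O,
                            uint16_t p_dropout_in_16bits,   // keep threshold
                            float rp_dropout)               // 1 / (1 - dropout prob)
    {
        // P_dropped: keep and rescale where the draw falls below the threshold.
        for(int i = 0; i < M * N; ++i)
            p_drop[i] = z[i] < p_dropout_in_16bits ? p[i] * rp_dropout : 0.f;

        // Y = P_dropout * V (the bug fixed above was feeding p instead of p_drop here).
        for(int m = 0; m < M; ++m)
            for(int o = 0; o < O; ++o)
            {
                float acc = 0.f;
                for(int n = 0; n < N; ++n)
                    acc += p_drop[m * N + n] * v[n * O + o];
                y[m * O + o] = acc;
            }
    }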
@@ -425,8 +426,8 @@ int run(int argc, char* argv[])
     {
         int M  = 128 * (rand() % 4 + 1);
         int N  = 128 * (rand() % 4 + 1);
-        int K  = 64;
-        int O  = 64;
+        int K  = 128;
+        int O  = 128;
         int G0 = rand() % 3 + 1;
         int G1 = rand() % 2 + 1;
         std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
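This hunk bumps the head dimensions K and O from 64 to 128 in the randomized test sizes. As a side note on the lengths vectors built here: q_gs_ms_ks_lengths = {G0, G1, M, K} describes a 4-D grouped Q tensor, and the example pairs each lengths vector with a separate strides vector. A sketch of how packed row-major strides would be derived for such a vector (hypothetical helper; the example is free to pick non-packed strides, which is why lengths and strides are passed independently):

    #include <cstdint>
    #include <vector>

    // Packed (fully contiguous) row-major strides for a lengths vector such as
    // {G0, G1, M, K}: stride[i] is the product of all lengths to its right.
    std::vector<int64_t> packed_strides(const std::vector<int64_t>& lengths)
    {
        std::vector<int64_t> strides(lengths.size(), 1);
        for(int i = static_cast<int>(lengths.size()) - 2; i >= 0; --i)
            strides[i] = strides[i + 1] * lengths[i + 1];
        return strides; // {G0, G1, M, K} -> {G1*M*K, M*K, K, 1}
    }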
@@ -720,6 +721,36 @@ int run(int argc, char* argv[])
         kgrad_tensors_device[i]->SetZero();
         vgrad_tensors_device[i]->SetZero();
     }
+    // p_z      = std::vector<void*>(p_z.size(), nullptr);
+    // argument =
+    //     gemm.MakeArgument(p_q,
+    //                       p_k,
+    //                       p_z,
+    //                       p_v,
+    //                       p_y,
+    //                       p_lse,
+    //                       p_ygrad,
+    //                       p_qgrad,
+    //                       p_kgrad,
+    //                       p_vgrad,
+    //                       {}, // std::array<void*, 1> p_acc0_biases;
+    //                       {}, // std::array<void*, 1> p_acc1_biases;
+    //                       problem_descs,
+    //                       QKVElementOp{},
+    //                       QKVElementOp{},
+    //                       Scale{alpha},
+    //                       QKVElementOp{},
+    //                       YElementOp{},
+    //                       p_drop,
+    //                       std::tuple<unsigned long long, unsigned long long>(seed, offset));
+    // DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
+    // gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
+    // if(!gemm.IsSupportedArgument(argument))
+    // {
+    //     std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
+    //     return 0;
+    // }
     invoker.Run(argument, StreamConfig{nullptr, false});
     for(std::size_t i = 0; i < group_count; i++)
     {
...
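The commented-out block added above preserves the recipe for re-running the op with the Z (dropout-mask) outputs nulled out. Condensed, it follows the standard CK device-op call sequence; the fragment below restates it with the names used in this example (it is not standalone, it relies on the example's variables and problem_descs):

    // Build the argument, attach workspace, check support, then run.
    auto argument = gemm.MakeArgument(/* tensor pointers, problem_descs,
                                         element-wise ops, p_drop, {seed, offset} */);
    DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
    gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
        return 0; // this instance cannot run the given problem sizes
    }
    auto invoker   = gemm.MakeInvoker();
    float ave_time = invoker.Run(argument, StreamConfig{nullptr, false});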
@@ -98,7 +98,7 @@ __global__ void
     unsigned short* z_matrix_ptr =
         (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                 : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
     GridwiseGemm::template Run<HasMainKBlockLoop>(
         arg_ptr[group_id].p_a_grid_ + a_batch_offset,
         arg_ptr[group_id].p_b_grid_ + b_batch_offset,
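The one subtlety in this kernel hunk is the guard around the per-group Z pointer: Z is an optional output, so the batch offset may only be applied when the caller actually supplied a buffer. The same guard, isolated as a trivial standalone sketch:

    #include <cstddef>

    // Returns the per-batch Z pointer, or nullptr when the Z output was skipped.
    unsigned short* z_batch_ptr(unsigned short* p_z_grid, std::ptrdiff_t z_batch_offset)
    {
        return p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset;
    }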
@@ -379,56 +379,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }
-    // V in Gemm B position
-    static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec,
-                                            const std::vector<index_t>& v_gs_os_ns_strides_vec)
-    {
-        // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
-        // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
-        // transformation overhead
-        // TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
-        // extract subsequence and shuffle them.
-        const index_t num_dims = NumDimG + NumDimN + NumDimO;
-        // 0, 1, .. NumDimG - 1
-        std::vector<index_t> gs_ids(NumDimG);
-        std::iota(gs_ids.begin(), gs_ids.end(), 0);
-        // NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
-        std::vector<index_t> os_ids(NumDimO);
-        std::iota(os_ids.begin(), os_ids.end(), NumDimG);
-        // NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
-        std::vector<index_t> ns_ids(NumDimN);
-        std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);
-        std::vector<index_t> ids_old2new;
-        ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
-        ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
-        ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
-        std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims);
-        for(int i = 0; i < num_dims; i++)
-        {
-            index_t id_new            = ids_old2new[i];
-            v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new];
-            v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new];
-        }
-        const auto v_grid_desc_nraw_oraw =
-            MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
-                v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
-                .second;
-        const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
-                                                         make_tuple(NPerBlock, Gemm1NPerBlock),
-                                                         Sequence<padder.PadN, padder.PadO>{});
-        // N_O to O0_N_O1; to refactor
-        return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
-    }
     //
     // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
     //
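The deleted MakeVGridDescriptor_O0_N_O1 becomes redundant because a later hunk switches the argument construction to the existing MakeB1GridDescriptor_BK0_N_BK1. Its core trick, permuting V's [Gs..., Os..., Ns...] lengths/strides into [Gs..., Ns..., Os...] order via std::iota-built id lists, is isolated in this standalone sketch (plain int dims for brevity; apply it to both the lengths and the strides vectors):

    #include <numeric>
    #include <vector>

    // Reorder a [Gs..., Os..., Ns...] vector into [Gs..., Ns..., Os...] order.
    std::vector<int> shuffle_os_ns(const std::vector<int>& v_gs_os_ns,
                                   int NumDimG, int NumDimO, int NumDimN)
    {
        std::vector<int> gs_ids(NumDimG), os_ids(NumDimO), ns_ids(NumDimN);
        std::iota(gs_ids.begin(), gs_ids.end(), 0);                 // 0 .. G-1
        std::iota(os_ids.begin(), os_ids.end(), NumDimG);           // G .. G+O-1
        std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO); // G+O .. G+O+N-1

        std::vector<int> ids_old2new;
        ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
        ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
        ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());

        std::vector<int> out(v_gs_os_ns.size());
        for(std::size_t i = 0; i < out.size(); ++i)
            out[i] = v_gs_os_ns[ids_old2new[i]];
        return out;
    }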
@@ -488,7 +438,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
     using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
     using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
-    using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
+    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
     using YGridDesc_M_O        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
     using LSEGridDesc_M        = decltype(MakeLSEGridDescriptor_M(1));
     using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
@@ -769,7 +719,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
             const auto z_grid_desc_m_n = DeviceOp::MakeZGridDescriptor_M_N(
                 problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
-            const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
+            const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
                 problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
                 problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
             const auto y_grid_desc_m_o = Transform::MakeCGridDescriptor_M_N(
@@ -927,16 +877,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
throw std::runtime_error("wrong! unsupported argument"); throw std::runtime_error("wrong! unsupported argument");
} }
// bool all_has_main_k_block_loop = true; bool all_has_main_k_block_loop = true;
// bool some_has_main_k_block_loop = false; bool some_has_main_k_block_loop = false;
// for(std::size_t i = 0; i < arg.group_count_; i++) for(std::size_t i = 0; i < arg.group_count_; i++)
// { {
// const auto K = arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * const auto K =
// arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
// const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K); const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
// all_has_main_k_block_loop &= y; all_has_main_k_block_loop &= y;
// some_has_main_k_block_loop |= y; some_has_main_k_block_loop |= y;
// } }
         hipGetErrorString(hipMemcpy(arg.p_workspace_,
                                     arg.group_kernel_args_.data(),
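The restored loop above replaces a hard-coded assumption with a proper scan: each group's Gemm0 K extent is checked, and two flags record whether all, or merely some, groups need the main K-block loop. Every group must agree, because the kernel is compiled for a single HasMainKBlockLoop value and launched once for all groups. The pattern, isolated as a standalone sketch (needs_main_loop is a hypothetical stand-in for GridwiseGemm::CalculateHasMainKBlockLoop):

    #include <stdexcept>
    #include <vector>

    // Returns the shared flag value, or throws when the groups disagree.
    template <typename Pred>
    bool uniform_flag(const std::vector<int>& Ks, Pred needs_main_loop)
    {
        bool all_true = true, some_true = false;
        for(int K : Ks)
        {
            const bool y = needs_main_loop(K);
            all_true  &= y;
            some_true |= y;
        }
        if(all_true != some_true) // mixed: some groups need the loop, some do not
            throw std::runtime_error("groups disagree on has_main_k_block_loop");
        return all_true;
    }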
@@ -976,19 +926,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
         // to concern Gemm0's loop
-        // if(all_has_main_k_block_loop)
-        // {
-        //     ave_time = launch_kernel(integral_constant<bool, true>{});
-        // }
-        // else if(!some_has_main_k_block_loop)
-        // {
-        ave_time = launch_kernel(integral_constant<bool, false>{});
-        // }
-        // else
-        // {
-        //     throw std::runtime_error("wrong! all gemm problems have to simultaneously meet "
-        //                              "has_main_k_block_loop or no_main_k_block_loop");
-        // }
+        if(all_has_main_k_block_loop)
+        {
+            ave_time = launch_kernel(integral_constant<bool, true>{});
+        }
+        else if(!some_has_main_k_block_loop)
+        {
+            ave_time = launch_kernel(integral_constant<bool, false>{});
+        }
+        else
+        {
+            throw std::runtime_error("wrong! all gemm problems have to simultaneously meet "
+                                     "has_main_k_block_loop or no_main_k_block_loop");
+        }
         return ave_time;
     }
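launch_kernel takes an integral_constant so that the main-loop decision, made at run time across all groups, selects one of two compile-time kernel instantiations. A self-contained sketch of that dispatch idiom (placeholder kernel body, not the CK kernel):

    #include <type_traits>

    template <bool HasMainKBlockLoop>
    float run_kernel_variant() // placeholder for the templated kernel launch
    {
        return HasMainKBlockLoop ? 1.f : 0.f;
    }

    // The tag's bool is a template parameter, so the branch inside the kernel
    // is resolved at compile time.
    template <bool B>
    float launch_kernel(std::integral_constant<bool, B>)
    {
        return run_kernel_variant<B>();
    }

    float dispatch(bool has_main_k_block_loop)
    {
        return has_main_k_block_loop ? launch_kernel(std::integral_constant<bool, true>{})
                                     : launch_kernel(std::integral_constant<bool, false>{});
    }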
@@ -1023,8 +973,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
             const index_t c_m      = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
             const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1);
             const index_t a_m      = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
-            const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
-                                      kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
+            const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
             if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
             {
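This final hunk brings the consistency check in line with the descriptor switch earlier in the commit: in a (BK0, N, BK1) descriptor the Gemm1 N extent sits at index 1, while the product of extents 0 and 2 reconstructs the split extent, which was correct for the old O0_N_O1 layout but not for the new one. A toy illustration with made-up sizes:

    #include <array>
    #include <cassert>

    int main()
    {
        constexpr int Gemm1K = 128, Gemm1N = 256, BK1 = 8;
        // Lengths of a (BK0, N, BK1) descriptor built by splitting Gemm1 K by BK1.
        std::array<int, 3> bk0_n_bk1{Gemm1K / BK1, Gemm1N, BK1};

        assert(bk0_n_bk1[1] == Gemm1N);                // new check: GetLength(I1)
        assert(bk0_n_bk1[0] * bk0_n_bk1[2] == Gemm1K); // the old product yields K, not N
        return 0;
    }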
...