"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "f71d3cfc7b2747a95d2305cf12af8dfe7e9a3e33"
Commit 6a2d7c9f authored by danyao12

fwd mqa/gqa

parent c459f488
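
This commit adds forward-pass support for multi-query (MQA) and grouped-query (GQA) attention to the fused attention kernels: G1 query heads (h_q) share G2 key/value heads (h_kv), with h_ratio = h_q / h_kv query heads reading each KV head. A minimal standalone sketch of the head mapping the kernels implement below (illustrative names, not the library's API):

#include <cassert>

// Map a query head to the KV head it shares under GQA. h_kv == 1 collapses
// to MQA (every query head reads KV head 0); h_kv == h_q is plain MHA.
int kv_head_of(int q_head, int h_q, int h_kv)
{
    assert(h_q % h_kv == 0); // head counts must divide evenly
    const int h_ratio = h_q / h_kv;
    return q_head / h_ratio; // e.g. h_q=12, h_kv=4: heads 0..2 -> 0, 3..5 -> 1, ...
}

The device kernels compute exactly this as gkv_idx = g_idx / h_ratio when resolving the K and V batch offsets.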
......@@ -75,7 +75,7 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecia
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true;
static constexpr bool Deterministic = false;
#if(DIM <= 32)
using DeviceGemmInstance =
......
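A side change in this first hunk: Deterministic flips from true back to false. Judging from the (Deterministic ? 1 : num_blocks_per_batch) block mapping later in this diff, true serializes a batch's blocks for reproducible results, so false restores the parallel scheduling; that reading is inferred from the diff, not from documentation.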
......@@ -18,7 +18,8 @@ int run(int argc, char* argv[])
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 7;
ck::index_t G1 = 13;
ck::index_t G1 = 12; // h_q
ck::index_t G2 = 12; // h_kv
bool input_permute = false;
bool output_permute = true;
......@@ -37,7 +38,7 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
else if(argc == 14)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
......@@ -49,20 +50,21 @@ int run(int argc, char* argv[])
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
G2 = std::stoi(argv[10]);
p_drop = std::stof(argv[10]);
p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[13]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
printf("arg11: p_drop\n");
printf("arg12 to 13: input / output permute\n");
exit(0);
}
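With the reworked argument list, a hypothetical 13-argument invocation (binary name assumed) looks like: ./example_mha_fwd 1 1 1 256 256 64 64 7 12 4 0.2 0 1, i.e. M=N=256, K=O=64, G0=7, G1=12 query heads sharing G2=4 KV heads (G1 must be a multiple of G2), p_drop=0.2, input unpermuted, output permuted.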
......@@ -77,17 +79,17 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides =
......@@ -286,11 +288,19 @@ int run(int argc, char* argv[])
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
b0_g_k_n.ForEach([&](auto& self, auto idx) {
const size_t g0 = idx[0] / G1;
const size_t g1 = idx[0] % G1;
const size_t g2 = g1 / (G1 / G2);
self(idx) = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
b1_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
self(idx) = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
});
z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
......
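This remap is the heart of the host-side verification change: the loops now iterate the expanded reference tensors and gather from the shared KV head, rather than scattering from the source as the A remap still does, because each KV head fans out to G1/G2 flattened batches. The same logic as the ForEach lambdas above, rewritten as plain loops over the example's host tensors (a sketch, not a drop-in replacement):

for(int g = 0; g < G0 * G1; ++g)
{
    const int g0 = g / G1;         // outer batch
    const int g1 = g % G1;         // query head
    const int g2 = g1 / (G1 / G2); // shared KV head
    for(int n = 0; n < N; ++n)
    {
        for(int k = 0; k < K; ++k)
            b0_g_k_n(g, k, n) = b0_gs_ns_ks(g0, g2, n, k); // K, stored transposed as [K, N]
        for(int o = 0; o < O; ++o)
            b1_g_n_o(g, n, o) = b1_gs_os_ns(g0, g2, o, n); // V, stored transposed as [N, O]
    }
}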
......@@ -11,6 +11,7 @@ int run(int argc, char* argv[])
bool output_permute = true;
float p_drop = 0.2;
int h_ratio = 1; // G1 / G2
const unsigned long long seed = 1;
const unsigned long long offset = 0;
......@@ -24,22 +25,25 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
else if(argc == 8)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
p_drop = std::stof(argv[4]);
input_permute = std::stoi(argv[5]);
output_permute = std::stoi(argv[6]);
h_ratio = std::stoi(argv[5]);
input_permute = std::stoi(argv[6]);
output_permute = std::stoi(argv[7]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 5: input / output permute\n");
printf("arg4: p_drop\n");
printf("arg5: h_ratio\n");
printf("arg6 to 7: input / output permute\n");
exit(0);
}
......@@ -88,7 +92,8 @@ int run(int argc, char* argv[])
int K = DIM;
int O = DIM;
int G0 = rand() % 3 + 1;
int G1 = rand() % 5 + 1;
int G2 = rand() % 5 + 1;
int G1 = G2 * h_ratio;
g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
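So with, say, h_ratio = 3 (a hypothetical run), each random test case draws G2 in [1, 5] and derives G1 = h_ratio * G2, e.g. G2 = 4 gives 12 query heads over 4 KV heads; h_ratio = 1 reproduces the old MHA shapes exactly.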
......@@ -98,17 +103,17 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides =
......@@ -253,7 +258,8 @@ int run(int argc, char* argv[])
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
p_drop, // dropout ratio
h_ratio,
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
......@@ -296,7 +302,8 @@ int run(int argc, char* argv[])
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
p_drop, // dropout ratio
h_ratio,
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
// specify workspace for problem_desc
......@@ -350,11 +357,19 @@ int run(int argc, char* argv[])
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
b0_g_k_n.ForEach([&](auto& self, auto idx) {
const size_t g0 = idx[0] / G1;
const size_t g1 = idx[0] % G1;
const size_t g2 = g1 / h_ratio;
self(idx) = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
b1_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
self(idx) = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
});
z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) {
......
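Same gather pattern as the batched example's verification, except the ratio arrives precomputed from the command line: g2 = g1 / h_ratio instead of g1 / (G1 / G2).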
......@@ -134,6 +134,7 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
......
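Because MakeArgumentPointer is pure virtual on the DeviceGroupedMultiheadAttentionForward base operator, adding h_ratio here is a breaking interface change: every concrete implementation and every call site must thread the parameter through, with h_ratio = 1 recovering plain MHA behavior.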
......@@ -254,6 +254,7 @@ __global__ void
ignore = ygrad_grid_desc_o0_m_o1;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = h_ratio;
ignore = nblock;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
......
......@@ -255,6 +255,7 @@ __global__ void
ignore = ygrad_grid_desc_m0_o_m1;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = h_ratio;
ignore = nblock;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
......
......@@ -78,6 +78,7 @@ __global__ void
const LSEGridDescriptor_M lse_grid_desc_m,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const index_t h_ratio,
const index_t mblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
......@@ -94,13 +95,14 @@ __global__ void
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(gkv_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(gkv_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
......@@ -211,6 +213,7 @@ __global__ void
ignore = lse_grid_desc_m;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = h_ratio;
ignore = mblock;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
......@@ -662,7 +665,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
h_ratio_{c_grid_desc_g_m_n_.GetLength(I0) / b_grid_desc_g_n_k_.GetLength(I0)}
{
// TODO ANT: implement bias addition
ignore = p_acc1_biases;
......@@ -736,10 +740,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
std::cout << "d0_grid_desc_g_m_n_: " << d0_grid_desc_g_m_n_.GetLength(I0) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I1) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I2) << '\n';
std::cout << "d0_grid_desc_m_n_: " << d0_grid_desc_m_n_.GetLength(I0) << ", "
<< d0_grid_desc_m_n_.GetLength(I1) << '\n';
std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
......@@ -802,6 +804,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
std::vector<index_t> c_mz_gemm1nz_strides_;
index_t batch_count_;
index_t h_ratio_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_dropout_;
......@@ -900,6 +903,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
arg.lse_grid_desc_m_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.h_ratio_,
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
......@@ -1014,12 +1018,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t b_g = arg.b_grid_desc_g_n_k_.GetLength(I0);
const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
b_g <= c_g))
{
return false;
}
......
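Note the asymmetry with the grouped op: the batched op derives h_ratio_ from the grid descriptors (c_g / b_g, i.e. (G0*G1) / (G0*G2)) instead of taking it as a parameter, and IsSupportedArgument now enforces divisibility. A condensed sketch of that shape contract:

// c_g = G0 * G1 (flattened output batches), b_g = G0 * G2 (flattened KV batches)
bool gqa_batch_counts_ok(int c_g, int b_g)
{
    // every KV batch must serve a whole number of output batches
    return b_g <= c_g && c_g % b_g == 0; // h_ratio = c_g / b_g
}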
......@@ -209,6 +209,7 @@ __global__ void
#else
ignore = group_kernel_args;
ignore = group_count;
ignore = h_ratio;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
......
......@@ -208,6 +208,7 @@ __global__ void
#else
ignore = group_kernel_args;
ignore = group_count;
ignore = h_ratio;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
......
......@@ -44,6 +44,7 @@ __global__ void
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const index_t h_ratio,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
......@@ -88,13 +89,14 @@ __global__ void
const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
const index_t g_idx = __builtin_amdgcn_readfirstlane(
(block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(gkv_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
......@@ -194,6 +196,7 @@ __global__ void
#else
ignore = group_kernel_args;
ignore = group_count;
ignore = h_ratio;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
......@@ -415,7 +418,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
}
......@@ -424,7 +426,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
}
......@@ -655,6 +656,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// for gridwise gemm check
CGridDesc_M_N c_grid_desc_m_n_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
// raw data
std::vector<ck::index_t> d0_n_length_stride_;
......@@ -679,12 +682,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op}
c_element_op_{c_element_op},
h_ratio_{h_ratio}
{
ignore = p_acc1_biases_vec;
// TODO ANT: implement bias addition
......@@ -855,6 +860,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n,
b_grid_desc_g_n_k,
c_grid_desc_g_m_n,
d0_n_length_stride});
}
......@@ -880,6 +887,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
index_t h_ratio_;
float p_dropout_;
uint8_t p_dropout_in_uint8_t_;
unsigned long long seed_;
......@@ -969,6 +977,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.group_count_,
arg.h_ratio_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
......@@ -1091,11 +1100,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
const auto& device_arg = arg.group_device_args_[i];
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t b_g = device_arg.b_grid_desc_g_n_k_.GetLength(I0);
const index_t c_m = device_arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
if(!(c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
c_g / b_g == arg.h_ratio_))
{
return false;
}
......@@ -1203,6 +1215,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_a_vec,
......@@ -1220,6 +1233,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_element_op,
c_element_op,
p_dropout,
h_ratio,
seeds};
}
......@@ -1242,6 +1256,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) override
{
return std::make_unique<Argument>(p_a_vec,
......@@ -1259,6 +1274,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_element_op,
c_element_op,
p_dropout,
h_ratio,
seeds);
}
......
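One consequence of the grouped variant taking a single h_ratio per Argument: the per-group check c_g / b_g == arg.h_ratio_ means every group in a launch must share the same query/KV head ratio; mixing ratios would require separate launches.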