"...composable_kernel.git" did not exist on "e06b987116261634c72bad8d760c99ef74c8cd32"
Commit 6a2d7c9f authored by danyao12's avatar danyao12
Browse files

fwd mqa/gqa

parent c459f488
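Note on the change: this commit adds forward-pass multi-query / grouped-query attention (MQA/GQA) support. The query tensor keeps G1 = h_q heads while K and V carry only G2 = h_kv heads, and every h_ratio = G1 / G2 consecutive query heads share one K/V head (h_ratio == 1 recovers standard MHA, G2 == 1 recovers MQA). Below is a minimal standalone sketch of the head mapping the kernels in this diff implement; the function name is illustrative and it assumes G1 % G2 == 0.

#include <cassert>

// Illustrative sketch (not CK code): map a flattened query-head batch index
// g_idx = g0 * G1 + g1 to the flattened K/V batch index it shares under GQA.
int map_q_batch_to_kv_batch(int g_idx, int G1, int G2)
{
    assert(G1 % G2 == 0);
    const int h_ratio = G1 / G2;      // query heads per K/V head
    const int g0      = g_idx / G1;   // batch index
    const int g1      = g_idx % G1;   // query head index (h_q)
    const int g2      = g1 / h_ratio; // shared K/V head index (h_kv)
    return g0 * G2 + g2;              // flattened K/V batch index
}

Because G1 = G2 * h_ratio, this mapping is equivalent to g_idx / h_ratio, which is exactly the gkv_idx form the device kernels below compute.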
@@ -75,7 +75,7 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecia
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = true;
+static constexpr bool Deterministic = false;
 #if(DIM <= 32)
 using DeviceGemmInstance =
...
@@ -18,7 +18,8 @@ int run(int argc, char* argv[])
     // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
     // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
     ck::index_t G0 = 7;
-    ck::index_t G1 = 13;
+    ck::index_t G1 = 12; // h_q
+    ck::index_t G2 = 12; // h_kv
     bool input_permute = false;
     bool output_permute = true;
@@ -37,7 +38,7 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 14)
    {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
@@ -49,20 +50,21 @@ int run(int argc, char* argv[])
         O  = std::stoi(argv[7]);
         G0 = std::stoi(argv[8]);
         G1 = std::stoi(argv[9]);
-        p_drop = std::stof(argv[10]);
-        input_permute = std::stoi(argv[11]);
-        output_permute = std::stoi(argv[12]);
+        G2 = std::stoi(argv[10]);
+        p_drop = std::stof(argv[11]);
+        input_permute = std::stoi(argv[12]);
+        output_permute = std::stoi(argv[13]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
+        printf("arg11: p_drop\n");
+        printf("arg12 to 13: input / output permute\n");
         exit(0);
     }
@@ -77,17 +79,17 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
             : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]

-    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
     std::vector<ck::index_t> b0_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
+            : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]

-    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
     std::vector<ck::index_t> b1_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
+            : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]

     std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
     std::vector<ck::index_t> c_gs_ms_os_strides =
@@ -286,11 +288,19 @@ int run(int argc, char* argv[])
         a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
             a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
         });
-        b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-            b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        b0_g_k_n.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0] / G1;
+            const size_t& g1 = idx[0] % G1;
+            const size_t& g2 = g1 / (G1 / G2);
+            self(idx) = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
         });
-        b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
-            b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        b1_g_n_o.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0] / G1;
+            const size_t& g1 = idx[0] % G1;
+            const size_t& g2 = g1 / (G1 / G2);
+            self(idx) = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
         });
         z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
             z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
...
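The reference-path remap above now broadcasts the G2 K/V heads across the G1 query heads instead of doing a 1:1 copy per head. A standalone sketch of the same broadcast on flat row-major float buffers (an assumption for illustration; the example itself uses ck::Tensor::ForEach):

#include <vector>

// Expand K (or V) from [G0, G2, N, K] to one slice per query head,
// [G0 * G1, N, K]: query head g1 reads from K/V head g2 = g1 / (G1 / G2).
std::vector<float> expand_kv_heads(const std::vector<float>& kv,
                                   int G0, int G1, int G2, int N, int K)
{
    const int h_ratio = G1 / G2;
    std::vector<float> out(static_cast<size_t>(G0) * G1 * N * K);
    for(int g = 0; g < G0 * G1; ++g) // flattened query-head batch index
    {
        const int g0 = g / G1;
        const int g2 = (g % G1) / h_ratio; // shared K/V head
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                out[(static_cast<size_t>(g) * N + n) * K + k] =
                    kv[((static_cast<size_t>(g0) * G2 + g2) * N + n) * K + k];
    }
    return out;
}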
@@ -11,6 +11,7 @@ int run(int argc, char* argv[])
     bool output_permute = true;
     float p_drop = 0.2;
+    int h_ratio = 1; // G1 / G2
     const unsigned long long seed = 1;
     const unsigned long long offset = 0;
@@ -24,22 +25,25 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 7)
+    else if(argc == 8)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
         p_drop = std::stoi(argv[4]);
-        input_permute = std::stoi(argv[5]);
-        output_permute = std::stoi(argv[6]);
+        h_ratio = std::stof(argv[5]);
+        input_permute = std::stoi(argv[6]);
+        output_permute = std::stoi(argv[7]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 5: input / output permute\n");
+        printf("arg4: p_drop\n");
+        printf("arg5: h_ratio\n");
+        printf("arg6 to 7: input / output permute\n");
         exit(0);
     }
@@ -88,7 +92,8 @@ int run(int argc, char* argv[])
         int K = DIM;
         int O = DIM;
         int G0 = rand() % 3 + 1;
-        int G1 = rand() % 5 + 1;
+        int G2 = rand() % 5 + 1;
+        int G1 = G2 * h_ratio;

         g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
@@ -98,17 +103,17 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
             : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]

-        std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+        std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
         std::vector<ck::index_t> b0_gs_ns_ks_strides =
             input_permute
-                ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-                : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
+                ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
+                : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]

-        std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+        std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
         std::vector<ck::index_t> b1_gs_os_ns_strides =
             input_permute
-                ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-                : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
+                ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
+                : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]

         std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
         std::vector<ck::index_t> c_gs_ms_os_strides =
@@ -253,7 +258,8 @@ int run(int argc, char* argv[])
             acc0_element_op,
             b1_element_op,
             c_element_op,
             p_drop, // dropout ratio
+            h_ratio,
             {seed, offset}); // dropout random seed and offset, offset should be
                              // at least the number of elements on a thread
@@ -296,7 +302,8 @@ int run(int argc, char* argv[])
             acc0_element_op,
             b1_element_op,
             c_element_op,
             p_drop, // dropout ratio
+            h_ratio,
             {seed, offset}); // dropout random seed and offset, offset should be
                              // at least the number of elements on a thread

         // specify workspace for problem_desc
@@ -350,11 +357,19 @@ int run(int argc, char* argv[])
         a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
             a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
         });
-        b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-            b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        b0_g_k_n.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0] / G1;
+            const size_t& g1 = idx[0] % G1;
+            const size_t& g2 = g1 / h_ratio;
+            self(idx) = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
         });
-        b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
-            b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        b1_g_n_o.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0] / G1;
+            const size_t& g1 = idx[0] % G1;
+            const size_t& g2 = g1 / h_ratio;
+            self(idx) = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
         });
         z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) {
...
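For reference, a hypothetical invocation of this updated grouped example through the new argc == 8 path (the binary name is illustrative): verification on, integer init, timing on, p_drop 0.2, h_ratio 2, no input permute, output permute on.

./example_grouped_mha_fwd 1 1 1 0.2 2 0 1

One caveat worth noting: in this revision p_drop is parsed with std::stoi while h_ratio uses std::stof, which looks swapped; a fractional p_drop on the command line truncates to 0.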
@@ -134,6 +134,7 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
         B1ElementwiseOperation b1_element_op,
         CElementwiseOperation c_element_op,
         float p_dropout,
+        index_t h_ratio,
         std::tuple<unsigned long long, unsigned long long> seeds) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
...
@@ -254,6 +254,7 @@ __global__ void
     ignore = ygrad_grid_desc_o0_m_o1;
     ignore = block_2_ctile_map;
     ignore = batch_count;
+    ignore = h_ratio;
     ignore = nblock;
     ignore = compute_base_ptr_of_batch;
     ignore = c0_matrix_mask;
...
@@ -255,6 +255,7 @@ __global__ void
     ignore = ygrad_grid_desc_m0_o_m1;
     ignore = block_2_ctile_map;
     ignore = batch_count;
+    ignore = h_ratio;
     ignore = nblock;
     ignore = compute_base_ptr_of_batch;
     ignore = c0_matrix_mask;
...
@@ -78,6 +78,7 @@ __global__ void
     const LSEGridDescriptor_M lse_grid_desc_m,
     const Block2CTileMap block_2_ctile_map,
     const index_t batch_count,
+    const index_t h_ratio,
     const index_t mblock,
     const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
     const C0MatrixMask c0_matrix_mask,
@@ -94,13 +95,14 @@ __global__ void
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+    const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);

     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
     const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
+        static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(gkv_idx)));
     const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
+        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(gkv_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
     const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
@@ -211,6 +213,7 @@ __global__ void
     ignore = lse_grid_desc_m;
     ignore = block_2_ctile_map;
     ignore = batch_count;
+    ignore = h_ratio;
     ignore = mblock;
     ignore = compute_base_ptr_of_batch;
     ignore = c0_matrix_mask;
@@ -662,7 +665,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
               b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
           c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
                                 c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
-          batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}
+          batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
+          h_ratio_{c_grid_desc_g_m_n_.GetLength(I0) / b_grid_desc_g_n_k_.GetLength(I0)}
     {
         // TODO ANT: implement bias addition
         ignore = p_acc1_biases;
@@ -736,10 +740,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             std::cout << "d0_grid_desc_g_m_n_: " << d0_grid_desc_g_m_n_.GetLength(I0) << ", "
                       << d0_grid_desc_g_m_n_.GetLength(I1) << ", "
                       << d0_grid_desc_g_m_n_.GetLength(I2) << '\n';
             std::cout << "d0_grid_desc_m_n_: " << d0_grid_desc_m_n_.GetLength(I0) << ", "
                       << d0_grid_desc_m_n_.GetLength(I1) << '\n';
             std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
                       << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
                       << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
@@ -802,6 +804,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         std::vector<index_t> c_mz_gemm1nz_strides_;
         index_t batch_count_;
+        index_t h_ratio_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
         float p_dropout_;
@@ -900,6 +903,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             arg.lse_grid_desc_m_,
             arg.block_2_ctile_map_,
             arg.batch_count_,
+            arg.h_ratio_,
             arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
             arg.compute_base_ptr_of_batch_,
             arg.c0_matrix_mask_,
@@ -1014,12 +1018,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         // Check if C permute dimension matches GEMM + GEMM shape
         const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
+        const index_t b_g = arg.b_grid_desc_g_n_k_.GetLength(I0);
         const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
         const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
         const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
         const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
-        if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
+        if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
+             b_g <= c_g))
         {
             return false;
         }
...
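The kernel-side pattern above generalizes across both the batched and grouped device ops: Q and the output advance their base pointer per query head (g_idx), while K and V advance per shared head (gkv_idx). A minimal standalone sketch of that selection, with raw per-batch strides standing in for CK's ComputeBasePtrOfStridedBatch (an assumption for illustration):

// Illustrative sketch (not CK code) of the per-block batch-offset selection.
struct BasePtrs
{
    long a, b, b1, c;
};

BasePtrs select_batch_offsets(long g_idx, long h_ratio,
                              long a_stride, long b_stride, long b1_stride, long c_stride)
{
    const long gkv_idx = g_idx / h_ratio; // each K/V head serves h_ratio query heads
    return {g_idx * a_stride,             // A (Q): one slice per query head
            gkv_idx * b_stride,           // B0 (K): one slice per K/V head
            gkv_idx * b1_stride,          // B1 (V): one slice per K/V head
            g_idx * c_stride};            // C (output): one slice per query head
}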
@@ -209,6 +209,7 @@ __global__ void
 #else
     ignore = group_kernel_args;
     ignore = group_count;
+    ignore = h_ratio;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = acc_element_op;
...
@@ -208,6 +208,7 @@ __global__ void
 #else
     ignore = group_kernel_args;
     ignore = group_count;
+    ignore = h_ratio;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = acc_element_op;
...
@@ -44,6 +44,7 @@ __global__ void
     kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
         const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
         const index_t group_count,
+        const index_t h_ratio,
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,
         const AccElementwiseOperation acc_element_op,
@@ -88,13 +89,14 @@ __global__ void
     const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
     const index_t g_idx = __builtin_amdgcn_readfirstlane(
         (block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
+    const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);

     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
+    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
+        arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx)));
     const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
-        arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
+        arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(gkv_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
     const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
@@ -194,6 +196,7 @@ __global__ void
 #else
     ignore = group_kernel_args;
     ignore = group_count;
+    ignore = h_ratio;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = acc_element_op;
@@ -415,7 +418,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
     MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
                              const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
     {
         return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
                                                   acc0_biases_gs_ms_ns_strides);
     }
@@ -424,7 +426,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
     MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
                                const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
     {
         return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
                                                     acc0_biases_gs_ms_ns_strides);
     }
@@ -655,6 +656,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         // for gridwise gemm check
         CGridDesc_M_N c_grid_desc_m_n_;
+        BGridDesc_G_N_K b_grid_desc_g_n_k_;
+        CGridDesc_G_M_N c_grid_desc_g_m_n_;

         // raw data
         std::vector<ck::index_t> d0_n_length_stride_;
@@ -679,12 +682,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
             B1ElementwiseOperation b1_element_op,
             CElementwiseOperation c_element_op,
             float p_dropout,
+            index_t h_ratio,
             std::tuple<unsigned long long, unsigned long long> seeds)
             : a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               acc_element_op_{acc_element_op},
               b1_element_op_{b1_element_op},
-              c_element_op_{c_element_op}
+              c_element_op_{c_element_op},
+              h_ratio_{h_ratio}
         {
             ignore = p_acc1_biases_vec;
             // TODO ANT: implement bias addition
@@ -855,6 +860,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                     {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
                      problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
                     c_grid_desc_m_n,
+                    b_grid_desc_g_n_k,
+                    c_grid_desc_g_m_n,
                     d0_n_length_stride});
             }
@@ -880,6 +887,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         B1ElementwiseOperation b1_element_op_;
         CElementwiseOperation c_element_op_;
+        index_t h_ratio_;
         float p_dropout_;
         uint8_t p_dropout_in_uint8_t_;
         unsigned long long seed_;
@@ -969,6 +977,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                 0,
                 cast_pointer_to_constant_address_space(arg.p_workspace_),
                 arg.group_count_,
+                arg.h_ratio_,
                 arg.a_element_op_,
                 arg.b_element_op_,
                 arg.acc_element_op_,
@@ -1091,11 +1100,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
             const auto& device_arg = arg.group_device_args_[i];

             // Check if C permute dimension matches GEMM + GEMM shape
+            const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
+            const index_t b_g = device_arg.b_grid_desc_g_n_k_.GetLength(I0);
             const index_t c_m = device_arg.c_grid_desc_m_n_.GetLength(I0);
             const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1);
             const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
             const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
-            if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
+            if(!(c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
+                 c_g / b_g == arg.h_ratio_))
             {
                 return false;
             }
@@ -1203,6 +1215,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         B1ElementwiseOperation b1_element_op,
         CElementwiseOperation c_element_op,
         float p_dropout,
+        index_t h_ratio,
         std::tuple<unsigned long long, unsigned long long> seeds)
     {
         return Argument{p_a_vec,
return Argument{p_a_vec, return Argument{p_a_vec,
...@@ -1220,6 +1233,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1220,6 +1233,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_dropout, p_dropout,
h_ratio,
seeds}; seeds};
} }
@@ -1242,6 +1256,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         B1ElementwiseOperation b1_element_op,
         CElementwiseOperation c_element_op,
         float p_dropout,
+        index_t h_ratio,
         std::tuple<unsigned long long, unsigned long long> seeds) override
     {
         return std::make_unique<Argument>(p_a_vec,
return std::make_unique<Argument>(p_a_vec, return std::make_unique<Argument>(p_a_vec,
...@@ -1259,6 +1274,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1259,6 +1274,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_dropout, p_dropout,
h_ratio,
seeds); seeds);
} }
......
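The grouped op's IsSupportedArgument change above requires each group's output batch count (G0 * G1) to be exactly h_ratio times its K/V batch count (G0 * G2). A self-contained toy check mirroring that logic (illustrative, not CK code):

#include <cstdio>

int main()
{
    const int h_ratio     = 2;
    const int groups[][2] = {{8, 4}, {6, 3}, {6, 4}}; // {c_g, b_g} per group
    for(const auto& g : groups)
    {
        // Same predicate as the device op: divisible, and ratio matches h_ratio.
        const bool ok = g[0] % g[1] == 0 && g[0] / g[1] == h_ratio;
        std::printf("c_g=%d b_g=%d -> %s\n", g[0], g[1], ok ? "supported" : "rejected");
    }
    return 0;
}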