Commit c459f488 authored by danyao12's avatar danyao12
Browse files

clean up codes

parent 174e013d
...@@ -15,8 +15,6 @@ add_example_executable(example_grouped_multihead_attention_forward_v2 grouped_mu ...@@ -15,8 +15,6 @@ add_example_executable(example_grouped_multihead_attention_forward_v2 grouped_mu
add_example_executable(example_batched_multihead_attention_forward_v2 batched_multihead_attention_forward_v2.cpp) add_example_executable(example_batched_multihead_attention_forward_v2 batched_multihead_attention_forward_v2.cpp)
add_example_executable(example_grouped_multihead_attention_backward_v2 grouped_multihead_attention_backward_v2.cpp) add_example_executable(example_grouped_multihead_attention_backward_v2 grouped_multihead_attention_backward_v2.cpp)
add_example_executable(example_batched_multihead_attention_backward_v2 batched_multihead_attention_backward_v2.cpp) add_example_executable(example_batched_multihead_attention_backward_v2 batched_multihead_attention_backward_v2.cpp)
add_example_executable(example_grouped_multihead_attention_backward_v2_protro grouped_multihead_attention_backward_v2_protro.cpp)
add_example_executable(example_batched_multihead_attention_backward_v2_protro batched_multihead_attention_backward_v2_protro.cpp)
add_example_executable(example_grouped_multihead_attention_train_v2 grouped_multihead_attention_train_v2.cpp) add_example_executable(example_grouped_multihead_attention_train_v2 grouped_multihead_attention_train_v2.cpp)
add_example_executable(example_batched_multihead_attention_train_v2 batched_multihead_attention_train_v2.cpp) add_example_executable(example_batched_multihead_attention_train_v2 batched_multihead_attention_train_v2.cpp)
add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp) add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp)
......
...@@ -276,7 +276,8 @@ int run(int argc, char* argv[]) ...@@ -276,7 +276,8 @@ int run(int argc, char* argv[])
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
ck::index_t G0 = 4; ck::index_t G0 = 4;
ck::index_t G1 = 6; ck::index_t G1 = 6; // h_q
ck::index_t G2 = 1; // h_kv
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
...@@ -295,7 +296,7 @@ int run(int argc, char* argv[]) ...@@ -295,7 +296,7 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 13) else if(argc == 14)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -307,20 +308,21 @@ int run(int argc, char* argv[]) ...@@ -307,20 +308,21 @@ int run(int argc, char* argv[])
O = std::stoi(argv[7]); O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1 = std::stoi(argv[9]);
G2 = std::stoi(argv[10]);
p_drop = std::stof(argv[10]); p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[13]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
printf("arg10: scale (alpha)\n"); printf("arg11: p_drop\n");
printf("arg11 to 12: input / output permute\n"); printf("arg12 to 13: input / output permute\n");
exit(0); exit(0);
} }
...@@ -338,6 +340,7 @@ int run(int argc, char* argv[]) ...@@ -338,6 +340,7 @@ int run(int argc, char* argv[])
std::cout << "O: " << O << std::endl; std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl; std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl; std::cout << "G1: " << G1 << std::endl;
std::cout << "G2: " << G2 << std::endl;
std::cout << "alpha: " << alpha << std::endl; std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl; std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl; std::cout << "output_permute: " << output_permute << std::endl;
...@@ -353,17 +356,17 @@ int run(int argc, char* argv[]) ...@@ -353,17 +356,17 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K] : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides = std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K] : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides = std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O] : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides = std::vector<ck::index_t> y_gs_ms_os_strides =
...@@ -376,6 +379,18 @@ int run(int argc, char* argv[]) ...@@ -376,6 +379,18 @@ int run(int argc, char* argv[])
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...) // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...))) // = exp(Si) / exp(log(sum(exp() + ...)))
...@@ -392,6 +407,8 @@ int run(int argc, char* argv[]) ...@@ -392,6 +407,8 @@ int run(int argc, char* argv[])
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl; std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
...@@ -399,6 +416,8 @@ int run(int argc, char* argv[]) ...@@ -399,6 +416,8 @@ int run(int argc, char* argv[])
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl; std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl; std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0}); z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method) switch(init_method)
...@@ -475,10 +494,20 @@ int run(int argc, char* argv[]) ...@@ -475,10 +494,20 @@ int run(int argc, char* argv[])
q_gs_ms_ks.ForEach( q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach( k_g_n_k.ForEach([&](auto& self, auto idx) {
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); const size_t& g0 = idx[0] / G1;
v_gs_os_ns.ForEach( const size_t& g1 = idx[0] % G1;
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); }); const size_t& g2 = g1 / (G1 / G2);
self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
});
v_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
});
// qkv gradients have the same descriptor as with qkv // qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
...@@ -488,8 +517,8 @@ int run(int argc, char* argv[]) ...@@ -488,8 +517,8 @@ int run(int argc, char* argv[])
DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize()); DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem kgrad_device_buf(sizeof(OutputDataType) * kgrad_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem vgrad_device_buf(sizeof(OutputDataType) * vgrad_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
q_device_buf.ToDevice(q_gs_ms_ks.mData.data()); q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
...@@ -528,6 +557,10 @@ int run(int argc, char* argv[]) ...@@ -528,6 +557,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...@@ -575,6 +608,10 @@ int run(int argc, char* argv[]) ...@@ -575,6 +608,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...@@ -757,12 +794,16 @@ int run(int argc, char* argv[]) ...@@ -757,12 +794,16 @@ int run(int argc, char* argv[])
#endif #endif
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_gs_ns_ks_lengths,
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides);
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_gs_ns_ks_lengths,
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides);
qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data()); qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data()); kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
......
...@@ -272,6 +272,7 @@ int run(int argc, char* argv[]) ...@@ -272,6 +272,7 @@ int run(int argc, char* argv[])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
float alpha = 1.f / std::sqrt(DIM); float alpha = 1.f / std::sqrt(DIM);
float p_drop = 0.0; float p_drop = 0.0;
int h_ratio = 1; // G1 / G2
bool input_permute = true; bool input_permute = true;
bool output_permute = true; bool output_permute = true;
...@@ -289,25 +290,26 @@ int run(int argc, char* argv[]) ...@@ -289,25 +290,26 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 7) else if(argc == 8)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
p_drop = std::stof(argv[4]); p_drop = std::stof(argv[4]);
h_ratio = std::stof(argv[5]);
input_permute = std::stoi(argv[5]); input_permute = std::stoi(argv[6]);
output_permute = std::stoi(argv[6]); output_permute = std::stoi(argv[7]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4: p_drop\n");
printf("arg10: scale (alpha)\n"); printf("arg5: h_ratio\n");
printf("arg11 to 12: input / output permute\n"); printf("arg6 to 7: input / output permute\n");
exit(0); exit(0);
} }
...@@ -372,24 +374,25 @@ int run(int argc, char* argv[]) ...@@ -372,24 +374,25 @@ int run(int argc, char* argv[])
int K = DIM; int K = DIM;
int O = DIM; int O = DIM;
int G0 = rand() % 4 + 1; int G0 = rand() % 4 + 1;
int G1 = rand() % 4 + 1; int G2 = rand() % 4 + 1;
int G1 = G2 * h_ratio;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides = std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K] : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides = std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K] : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides = std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O] : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides = std::vector<ck::index_t> y_gs_ms_os_strides =
...@@ -402,6 +405,17 @@ int run(int argc, char* argv[]) ...@@ -402,6 +405,17 @@ int run(int argc, char* argv[])
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...) // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...))) // = exp(Si) / exp(log(sum(exp() + ...)))
...@@ -423,6 +437,10 @@ int run(int argc, char* argv[]) ...@@ -423,6 +437,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
lse_gs_ms_strides, lse_gs_ms_strides,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...@@ -446,6 +464,8 @@ int run(int argc, char* argv[]) ...@@ -446,6 +464,8 @@ int run(int argc, char* argv[])
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
if(i < 4) if(i < 4)
{ {
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
...@@ -454,6 +474,8 @@ int run(int argc, char* argv[]) ...@@ -454,6 +474,8 @@ int run(int argc, char* argv[])
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl; std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl; std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
} }
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0}); z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method) switch(init_method)
...@@ -531,11 +553,19 @@ int run(int argc, char* argv[]) ...@@ -531,11 +553,19 @@ int run(int argc, char* argv[])
q_gs_ms_ks.ForEach([&](auto& self, auto idx) { q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
k_gs_ns_ks.ForEach([&](auto& self, auto idx) { k_g_n_k.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
}); });
v_gs_os_ns.ForEach([&](auto& self, auto idx) { v_g_n_o.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
}); });
q_g_m_ks.push_back(q_g_m_k); q_g_m_ks.push_back(q_g_m_k);
...@@ -554,6 +584,8 @@ int run(int argc, char* argv[]) ...@@ -554,6 +584,8 @@ int run(int argc, char* argv[])
z_tensors.push_back(z_gs_ms_ns); z_tensors.push_back(z_gs_ms_ns);
lse_tensors.push_back(lse_gs_ms); lse_tensors.push_back(lse_gs_ms);
ygrad_tensors.push_back(ygrad_gs_ms_os); ygrad_tensors.push_back(ygrad_gs_ms_os);
kgrad_tensors.push_back(kgrad_gs_ns_ks);
vgrad_tensors.push_back(vgrad_gs_os_ns);
q_tensors_device.emplace_back( q_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
k_tensors_device.emplace_back( k_tensors_device.emplace_back(
...@@ -568,10 +600,10 @@ int run(int argc, char* argv[]) ...@@ -568,10 +600,10 @@ int run(int argc, char* argv[])
std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
qgrad_tensors_device.emplace_back( qgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
kgrad_tensors_device.emplace_back( kgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize())); sizeof(OutputDataType) * kgrad_gs_ns_ks.GetElementSpaceSize()));
vgrad_tensors_device.emplace_back( vgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize())); sizeof(OutputDataType) * vgrad_gs_os_ns.GetElementSpaceSize()));
ygrad_tensors_device.emplace_back( ygrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));
q_tensors_device.back()->ToDevice(q_gs_ms_ks.data()); q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
...@@ -613,6 +645,7 @@ int run(int argc, char* argv[]) ...@@ -613,6 +645,7 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
...@@ -661,6 +694,7 @@ int run(int argc, char* argv[]) ...@@ -661,6 +694,7 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer()); gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
...@@ -674,7 +708,7 @@ int run(int argc, char* argv[]) ...@@ -674,7 +708,7 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int G1 = v_tensors[i].GetLengths()[1]; int G1 = q_tensors[i].GetLengths()[1];
// copy z matirx data form device // copy z matirx data form device
z_tensors_device[i]->FromDevice(z_tensors[i].mData.data()); z_tensors_device[i]->FromDevice(z_tensors[i].mData.data());
z_tensors[i].ForEach([&](auto& self, auto idx) { z_tensors[i].ForEach([&](auto& self, auto idx) {
...@@ -711,8 +745,8 @@ int run(int argc, char* argv[]) ...@@ -711,8 +745,8 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int G0 = v_tensors[i].GetLengths()[0]; int G0 = q_tensors[i].GetLengths()[0];
int G1 = v_tensors[i].GetLengths()[1]; int G1 = q_tensors[i].GetLengths()[1];
int O = v_tensors[i].GetLengths()[2]; int O = v_tensors[i].GetLengths()[2];
int N = v_tensors[i].GetLengths()[3]; int N = v_tensors[i].GetLengths()[3];
int M = q_tensors[i].GetLengths()[2]; int M = q_tensors[i].GetLengths()[2];
...@@ -770,17 +804,17 @@ int run(int argc, char* argv[]) ...@@ -770,17 +804,17 @@ int run(int argc, char* argv[])
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(), Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); kgrad_tensors[i].GetStrides());
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); vgrad_tensors[i].GetStrides());
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(), Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); kgrad_tensors[i].GetStrides());
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); vgrad_tensors[i].GetStrides());
qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data()); qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data()); kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
......
...@@ -74,16 +74,19 @@ __global__ void ...@@ -74,16 +74,19 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3, const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDescriptor_M lse_grid_desc_m, const LSEGridDescriptor_M lse_grid_desc_m,
const YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1, const YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
const index_t h_ratio,
const index_t nblock, const index_t nblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
...@@ -99,21 +102,26 @@ __global__ void ...@@ -99,21 +102,26 @@ __global__ void
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch // NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets // offsets
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(gkv_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(gkv_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
const long_index_t bgrad_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBGradBasePtr(g_idx)));
const long_index_t b1grad_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1GradBasePtr(g_idx)));
ck::philox ph(seed, 0, offset); ck::philox ph(seed, 0, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset); ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
...@@ -149,9 +157,9 @@ __global__ void ...@@ -149,9 +157,9 @@ __global__ void
p_lse_grid + lse_batch_offset, p_lse_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset, p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset, p_kgrad_grid + bgrad_batch_offset,
tmp_p_d0grad_grid, tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset, p_vgrad_grid + b1grad_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -160,9 +168,11 @@ __global__ void ...@@ -160,9 +168,11 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
ygrad_grid_desc_o0_m_o1, ygrad_grid_desc_o0_m_o1,
...@@ -187,9 +197,9 @@ __global__ void ...@@ -187,9 +197,9 @@ __global__ void
p_lse_grid + lse_batch_offset, p_lse_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset, p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset, p_kgrad_grid + bgrad_batch_offset,
tmp_p_d0grad_grid, tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset, p_vgrad_grid + b1grad_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -198,9 +208,11 @@ __global__ void ...@@ -198,9 +208,11 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
ygrad_grid_desc_o0_m_o1, ygrad_grid_desc_o0_m_o1,
...@@ -232,9 +244,11 @@ __global__ void ...@@ -232,9 +244,11 @@ __global__ void
ignore = c_element_op; ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = bgrad_grid_desc_bk0_n_bk1;
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3; ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3; ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
ignore = b1_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1;
ignore = b1grad_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = lse_grid_desc_m; ignore = lse_grid_desc_m;
ignore = ygrad_grid_desc_o0_m_o1; ignore = ygrad_grid_desc_o0_m_o1;
...@@ -603,6 +617,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -603,6 +617,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const BGridDesc_G_N_K& bgrad_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1grad_grid_desc_g_n_k,
index_t BatchStrideLSE) index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
...@@ -610,6 +626,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -610,6 +626,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
bgrad_grid_desc_g_n_k_(bgrad_grid_desc_g_n_k),
b1grad_grid_desc_g_n_k_(b1grad_grid_desc_g_n_k),
BatchStrideLSE_(BatchStrideLSE) BatchStrideLSE_(BatchStrideLSE)
{ {
} }
...@@ -649,6 +667,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -649,6 +667,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return g_idx * static_cast<long_index_t>(BatchStrideLSE_); return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
} }
__host__ __device__ constexpr long_index_t GetBGradBasePtr(index_t g_idx) const
{
return bgrad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1GradBasePtr(index_t g_idx) const
{
return b1grad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
...@@ -656,6 +684,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -656,6 +684,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
BGridDesc_G_N_K bgrad_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1grad_grid_desc_g_n_k_;
index_t BatchStrideLSE_; index_t BatchStrideLSE_;
}; };
...@@ -755,6 +785,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -755,6 +785,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_strides,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
...@@ -784,9 +818,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -784,9 +818,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
bgrad_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
bgrad_gs_ns_ks_lengths, bgrad_gs_ns_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1( b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
b1grad_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])}, lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
...@@ -805,6 +843,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -805,6 +843,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
bgrad_grid_desc_g_n_k_{Transform::MakeB0GridDescriptor_G_N_K(bgrad_gs_ns_ks_lengths,
bgrad_gs_ns_ks_strides)},
b1grad_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{}, y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -826,6 +868,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -826,6 +868,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
h_ratio_{c_grid_desc_g_m_n_.GetLength(I0) / b_grid_desc_g_n_k_.GetLength(I0)},
p_drop_{p_drop} p_drop_{p_drop}
{ {
// TODO: implement bias addition // TODO: implement bias addition
...@@ -864,6 +907,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -864,6 +907,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
z_grid_desc_g_m_n_, z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_, b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_, c_grid_desc_g_m_n_,
bgrad_grid_desc_g_n_k_,
b1grad_grid_desc_g_n_k_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
...@@ -887,7 +932,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -887,7 +932,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<< b_grid_desc_g_n_k_.GetLength(I1) << ", " << b_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b_grid_desc_g_n_k_.Print(); // b_grid_desc_g_n_k_.Print();
std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", " std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", " << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1_grid_desc_g_n_k_.Print(); // b1_grid_desc_g_n_k_.Print();
...@@ -900,6 +945,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -900,6 +945,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::cout << "ygrad_grid_desc_o0_m_o1_: " << ygrad_grid_desc_o0_m_o1_.GetLength(I0) std::cout << "ygrad_grid_desc_o0_m_o1_: " << ygrad_grid_desc_o0_m_o1_.GetLength(I0)
<< ", " << ygrad_grid_desc_o0_m_o1_.GetLength(I1) << ", " << ", " << ygrad_grid_desc_o0_m_o1_.GetLength(I1) << ", "
<< ygrad_grid_desc_o0_m_o1_.GetLength(I2) << '\n'; << ygrad_grid_desc_o0_m_o1_.GetLength(I2) << '\n';
std::cout << "d0_grid_desc_g_m_n_: " << d0_grid_desc_g_m_n_.GetLength(I0) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I1) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I2) << '\n';
std::cout << "bgrad_grid_desc_g_n_k_: " << bgrad_grid_desc_g_n_k_.GetLength(I0) << ", "
<< bgrad_grid_desc_g_n_k_.GetLength(I1) << ", "
<< bgrad_grid_desc_g_n_k_.GetLength(I2) << '\n';
// bgrad_grid_desc_g_n_k_.Print();
std::cout << "b1grad_grid_desc_g_n_k_: " << b1grad_grid_desc_g_n_k_.GetLength(I0)
<< ", " << b1grad_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1grad_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1grad_grid_desc_g_n_k_.Print();
} }
// pointers // pointers
...@@ -919,9 +975,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -919,9 +975,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_; typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_; KGridDesc_N_K k_grid_desc_n_k_;
...@@ -934,6 +992,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -934,6 +992,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
BGridDesc_G_N_K bgrad_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1grad_grid_desc_g_n_k_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_; y_grid_desc_mblock_mperblock_oblock_operblock_;
...@@ -961,6 +1021,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -961,6 +1021,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<index_t> c_mz_gemm1nz_strides_; std::vector<index_t> c_mz_gemm1nz_strides_;
index_t batch_count_; index_t batch_count_;
index_t h_ratio_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_drop_; float p_drop_;
...@@ -1047,14 +1108,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1047,14 +1108,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.bgrad_grid_desc_bk0_n_bk1_,
arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.b1grad_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_, arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
arg.ygrad_grid_desc_o0_m_o1_, arg.ygrad_grid_desc_o0_m_o1_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
arg.h_ratio_,
arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_), arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_),
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
...@@ -1108,13 +1172,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1108,13 +1172,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t b_g = arg.b_grid_desc_g_n_k_.GetLength(I0);
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0); const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1); const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
b_g <= c_g))
{ {
return false; return false;
} }
...@@ -1203,6 +1269,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1203,6 +1269,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_strides,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
...@@ -1242,6 +1312,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1242,6 +1312,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
bgrad_gs_ns_ks_lengths,
bgrad_gs_ns_ks_strides,
b1grad_gs_gemm1ns_gemm1ks_lengths,
b1grad_gs_gemm1ns_gemm1ks_strides,
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
...@@ -1285,6 +1359,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1285,6 +1359,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_strides,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
...@@ -1325,6 +1403,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1325,6 +1403,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
bgrad_gs_ns_ks_lengths,
bgrad_gs_ns_ks_strides,
b1grad_gs_gemm1ns_gemm1ks_lengths,
b1grad_gs_gemm1ns_gemm1ks_strides,
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths, acc1_bias_gs_ms_gemm1ns_lengths,
......
...@@ -44,6 +44,7 @@ __global__ void ...@@ -44,6 +44,7 @@ __global__ void
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1( kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count, const index_t group_count,
const index_t h_ratio,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op, const AccElementwiseOperation acc_element_op,
...@@ -82,19 +83,26 @@ __global__ void ...@@ -82,19 +83,26 @@ __global__ void
const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_; const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
const index_t g_idx = __builtin_amdgcn_readfirstlane( const index_t g_idx = __builtin_amdgcn_readfirstlane(
(block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch)); (block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>( const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(gkv_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
const long_index_t bgrad_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetBGradBasePtr(g_idx)));
const long_index_t b1grad_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1GradBasePtr(g_idx)));
const index_t global_thread_id = get_thread_global_1d_id(); const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset); ck::philox ph(seed, global_thread_id, offset);
...@@ -129,9 +137,9 @@ __global__ void ...@@ -129,9 +137,9 @@ __global__ void
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset, arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + bgrad_batch_offset,
tmp_p_d0grad_grid, tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1grad_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -140,9 +148,11 @@ __global__ void ...@@ -140,9 +148,11 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].bgrad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1grad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_, arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_, arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
...@@ -168,9 +178,9 @@ __global__ void ...@@ -168,9 +178,9 @@ __global__ void
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset, arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + bgrad_batch_offset,
tmp_p_d0grad_grid, tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1grad_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -179,9 +189,11 @@ __global__ void ...@@ -179,9 +189,11 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].bgrad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1grad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_, arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_, arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
...@@ -307,6 +319,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -307,6 +319,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
std::vector<index_t> bgrad_gs_ns_ks_lengths;
std::vector<index_t> bgrad_gs_ns_ks_strides;
std::vector<index_t> b1grad_gs_gemm1ns_gemm1ks_lengths;
std::vector<index_t> b1grad_gs_gemm1ns_gemm1ks_strides;
std::vector<index_t> acc0_bias_gs_ms_ns_lengths; std::vector<index_t> acc0_bias_gs_ms_ns_lengths;
std::vector<index_t> acc0_bias_gs_ms_ns_strides; std::vector<index_t> acc0_bias_gs_ms_ns_strides;
...@@ -508,7 +526,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -508,7 +526,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
...@@ -517,7 +534,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -517,7 +534,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
...@@ -564,6 +580,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -564,6 +580,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const BGridDesc_G_N_K& bgrad_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1grad_grid_desc_g_n_k,
index_t batch_stride_lse) index_t batch_stride_lse)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
...@@ -571,6 +589,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -571,6 +589,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
bgrad_grid_desc_g_n_k_(bgrad_grid_desc_g_n_k),
b1grad_grid_desc_g_n_k_(b1grad_grid_desc_g_n_k),
batch_stride_lse_(batch_stride_lse) batch_stride_lse_(batch_stride_lse)
{ {
} }
...@@ -610,6 +630,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -610,6 +630,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return g_idx * static_cast<long_index_t>(batch_stride_lse_); return g_idx * static_cast<long_index_t>(batch_stride_lse_);
} }
__host__ __device__ constexpr long_index_t GetBGradBasePtr(index_t g_idx) const
{
return bgrad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1GradBasePtr(index_t g_idx) const
{
return b1grad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
...@@ -617,6 +647,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -617,6 +647,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
BGridDesc_G_N_K bgrad_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1grad_grid_desc_g_n_k_;
index_t batch_stride_lse_; index_t batch_stride_lse_;
}; };
...@@ -708,9 +740,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -708,9 +740,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_; typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
...@@ -745,6 +779,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -745,6 +779,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<index_t> c_mz_gemm1nz_strides_; std::vector<index_t> c_mz_gemm1nz_strides_;
// for gridwise gemm check // for gridwise gemm check
BGridDesc_G_N_K b_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
index_t batch_count_; index_t batch_count_;
...@@ -776,13 +811,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -776,13 +811,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op}, : a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
p_dropout_{p_drop} p_dropout_{p_drop},
h_ratio_{h_ratio}
{ {
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
...@@ -840,6 +877,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -840,6 +877,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1( const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto bgrad_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.bgrad_gs_ns_ks_lengths, problem_desc.bgrad_gs_ns_ks_strides);
std::vector<index_t> tmp_d0_gs_ms_ns_lengths; std::vector<index_t> tmp_d0_gs_ms_ns_lengths;
std::vector<index_t> tmp_d0_gs_ms_ns_strides; std::vector<index_t> tmp_d0_gs_ms_ns_strides;
...@@ -862,6 +901,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -862,6 +901,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1( const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths, problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1_gs_gemm1ns_gemm1ks_strides); problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto b1grad_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
problem_desc.b1grad_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1grad_gs_gemm1ns_gemm1ks_strides);
const auto y_grid_desc_m_o = Transform::MakeCGridDescriptor_M_N( const auto y_grid_desc_m_o = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides); problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
...@@ -885,6 +927,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -885,6 +927,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc.b1_gs_gemm1ns_gemm1ks_strides); problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides); problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
const auto bgrad_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
problem_desc.bgrad_gs_ns_ks_lengths, problem_desc.bgrad_gs_ns_ks_strides);
const auto b1grad_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1grad_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1grad_gs_gemm1ns_gemm1ks_strides);
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock; y_grid_desc_mblock_mperblock_oblock_operblock;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
...@@ -918,6 +965,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -918,6 +965,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
z_grid_desc_g_m_n, z_grid_desc_g_m_n,
b1_grid_desc_g_n_k, b1_grid_desc_g_n_k,
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
bgrad_grid_desc_g_n_k,
b1grad_grid_desc_g_n_k,
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask // C0 mask
...@@ -945,9 +994,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -945,9 +994,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_vgrad_grid, p_vgrad_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
z_grid_desc_m_n, z_grid_desc_m_n,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
y_grid_desc_m_o, y_grid_desc_m_o,
y_grid_desc_mblock_mperblock_oblock_operblock, y_grid_desc_mblock_mperblock_oblock_operblock,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
...@@ -985,6 +1036,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -985,6 +1036,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc.b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]}, problem_desc.b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
{problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], {problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
b_grid_desc_g_n_k,
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
batch_count, batch_count,
d0_n_length_stride}); d0_n_length_stride});
...@@ -1011,6 +1063,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1011,6 +1063,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
index_t grid_size_; index_t grid_size_;
index_t group_count_; index_t group_count_;
index_t h_ratio_;
std::vector<GroupKernelArg> group_kernel_args_; std::vector<GroupKernelArg> group_kernel_args_;
std::vector<GroupDeviceArg> group_device_args_; std::vector<GroupDeviceArg> group_device_args_;
...@@ -1070,6 +1123,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1070,6 +1123,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
0, 0,
cast_pointer_to_constant_address_space(arg.p_workspace_), cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.group_count_, arg.group_count_,
arg.h_ratio_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.acc_element_op_, arg.acc_element_op_,
...@@ -1138,13 +1192,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1138,13 +1192,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto& device_arg = arg.group_device_args_[i]; const auto& device_arg = arg.group_device_args_[i];
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t b_g = device_arg.b_grid_desc_g_n_k_.GetLength(I0);
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0); const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1); const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n &&
c_g % b_g == 0 && c_g / b_g == arg.h_ratio_))
{ {
return false; return false;
} }
...@@ -1240,6 +1296,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1240,6 +1296,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_As, return Argument{p_As,
...@@ -1263,6 +1320,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1263,6 +1320,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, p_drop,
h_ratio,
seeds}; seeds};
} }
...@@ -1292,6 +1350,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1292,6 +1350,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(p_As, return std::make_unique<Argument>(p_As,
...@@ -1315,6 +1374,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1315,6 +1374,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, p_drop,
h_ratio,
seeds); seeds);
} }
......
...@@ -44,6 +44,7 @@ __global__ void ...@@ -44,6 +44,7 @@ __global__ void
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2( kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count, const index_t group_count,
const index_t h_ratio,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op, const AccElementwiseOperation acc_element_op,
...@@ -82,19 +83,26 @@ __global__ void ...@@ -82,19 +83,26 @@ __global__ void
const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_; const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
const index_t g_idx = __builtin_amdgcn_readfirstlane( const index_t g_idx = __builtin_amdgcn_readfirstlane(
(block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch)); (block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>( const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(gkv_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
const long_index_t bgrad_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetBGradBasePtr(g_idx)));
const long_index_t b1grad_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1GradBasePtr(g_idx)));
const index_t global_thread_id = get_thread_global_1d_id(); const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset); ck::philox ph(seed, global_thread_id, offset);
...@@ -128,9 +136,9 @@ __global__ void ...@@ -128,9 +136,9 @@ __global__ void
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset, arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + bgrad_batch_offset,
tmp_p_d0grad_grid, tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1grad_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -139,9 +147,11 @@ __global__ void ...@@ -139,9 +147,11 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].bgrad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1grad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_, arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_, arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
...@@ -167,9 +177,9 @@ __global__ void ...@@ -167,9 +177,9 @@ __global__ void
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset, arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + bgrad_batch_offset,
tmp_p_d0grad_grid, tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1grad_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -178,9 +188,11 @@ __global__ void ...@@ -178,9 +188,11 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].bgrad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1grad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_, arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_, arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
...@@ -313,6 +325,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -313,6 +325,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
std::vector<index_t> bgrad_gs_ns_ks_lengths;
std::vector<index_t> bgrad_gs_ns_ks_strides;
std::vector<index_t> b1grad_gs_gemm1ns_gemm1ks_lengths;
std::vector<index_t> b1grad_gs_gemm1ns_gemm1ks_strides;
std::vector<index_t> acc0_bias_gs_ms_ns_lengths; std::vector<index_t> acc0_bias_gs_ms_ns_lengths;
std::vector<index_t> acc0_bias_gs_ms_ns_strides; std::vector<index_t> acc0_bias_gs_ms_ns_strides;
...@@ -570,7 +588,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -570,7 +588,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
...@@ -579,7 +596,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -579,7 +596,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
...@@ -626,6 +642,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -626,6 +642,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const BGridDesc_G_N_K& bgrad_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1grad_grid_desc_g_n_k,
index_t BatchStrideLSE) index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
...@@ -633,6 +651,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -633,6 +651,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
bgrad_grid_desc_g_n_k_(bgrad_grid_desc_g_n_k),
b1grad_grid_desc_g_n_k_(b1grad_grid_desc_g_n_k),
BatchStrideLSE_(BatchStrideLSE) BatchStrideLSE_(BatchStrideLSE)
{ {
} }
...@@ -672,6 +692,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -672,6 +692,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return g_idx * static_cast<long_index_t>(BatchStrideLSE_); return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
} }
__host__ __device__ constexpr long_index_t GetBGradBasePtr(index_t g_idx) const
{
return bgrad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1GradBasePtr(index_t g_idx) const
{
return b1grad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
...@@ -679,6 +709,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -679,6 +709,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
BGridDesc_G_N_K bgrad_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1grad_grid_desc_g_n_k_;
index_t BatchStrideLSE_; index_t BatchStrideLSE_;
}; };
...@@ -778,9 +810,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -778,9 +810,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_; typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
...@@ -815,6 +849,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -815,6 +849,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<index_t> c_mz_gemm1nz_strides_; std::vector<index_t> c_mz_gemm1nz_strides_;
// for gridwise gemm check // for gridwise gemm check
BGridDesc_G_N_K b_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
index_t batch_count_; index_t batch_count_;
...@@ -846,13 +881,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -846,13 +881,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op}, : a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
p_dropout_{p_drop} p_dropout_{p_drop},
h_ratio_{h_ratio}
{ {
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
...@@ -910,6 +947,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -910,6 +947,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1( const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto bgrad_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.bgrad_gs_ns_ks_lengths, problem_desc.bgrad_gs_ns_ks_strides);
std::vector<index_t> tmp_d0_gs_ms_ns_lengths; std::vector<index_t> tmp_d0_gs_ms_ns_lengths;
std::vector<index_t> tmp_d0_gs_ms_ns_strides; std::vector<index_t> tmp_d0_gs_ms_ns_strides;
...@@ -932,6 +971,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -932,6 +971,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1( const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths, problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1_gs_gemm1ns_gemm1ks_strides); problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto b1grad_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
problem_desc.b1grad_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1grad_gs_gemm1ns_gemm1ks_strides);
const auto y_grid_desc_m_o = Transform::MakeCGridDescriptor_M_N( const auto y_grid_desc_m_o = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides); problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
...@@ -955,6 +997,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -955,6 +997,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc.b1_gs_gemm1ns_gemm1ks_strides); problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides); problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
const auto bgrad_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
problem_desc.bgrad_gs_ns_ks_lengths, problem_desc.bgrad_gs_ns_ks_strides);
const auto b1grad_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1grad_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1grad_gs_gemm1ns_gemm1ks_strides);
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock; y_grid_desc_mblock_mperblock_oblock_operblock;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
...@@ -988,6 +1035,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -988,6 +1035,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
z_grid_desc_g_m_n, z_grid_desc_g_m_n,
b1_grid_desc_g_n_k, b1_grid_desc_g_n_k,
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
bgrad_grid_desc_g_n_k,
b1grad_grid_desc_g_n_k,
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask // C0 mask
...@@ -1015,9 +1064,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1015,9 +1064,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_vgrad_grid, p_vgrad_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
z_grid_desc_m_n, z_grid_desc_m_n,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
y_grid_desc_m_o, y_grid_desc_m_o,
y_grid_desc_mblock_mperblock_oblock_operblock, y_grid_desc_mblock_mperblock_oblock_operblock,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
...@@ -1055,6 +1106,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1055,6 +1106,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc.b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]}, problem_desc.b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
{problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], {problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
b_grid_desc_g_n_k,
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
batch_count, batch_count,
d0_n_length_stride}); d0_n_length_stride});
...@@ -1081,6 +1133,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1081,6 +1133,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
index_t grid_size_; index_t grid_size_;
index_t group_count_; index_t group_count_;
index_t h_ratio_;
std::vector<GroupKernelArg> group_kernel_args_; std::vector<GroupKernelArg> group_kernel_args_;
std::vector<GroupDeviceArg> group_device_args_; std::vector<GroupDeviceArg> group_device_args_;
...@@ -1139,6 +1192,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1139,6 +1192,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
0, 0,
cast_pointer_to_constant_address_space(arg.p_workspace_), cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.group_count_, arg.group_count_,
arg.h_ratio_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.acc_element_op_, arg.acc_element_op_,
...@@ -1207,13 +1261,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1207,13 +1261,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto& device_arg = arg.group_device_args_[i]; const auto& device_arg = arg.group_device_args_[i];
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t b_g = device_arg.b_grid_desc_g_n_k_.GetLength(I0);
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0); const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1); const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n &&
c_g % b_g == 0 && c_g / b_g == arg.h_ratio_))
{ {
return false; return false;
} }
...@@ -1315,6 +1371,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1315,6 +1371,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_As, return Argument{p_As,
...@@ -1338,6 +1395,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1338,6 +1395,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, p_drop,
h_ratio,
seeds}; seeds};
} }
...@@ -1367,6 +1425,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1367,6 +1425,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(p_As, return std::make_unique<Argument>(p_As,
...@@ -1390,6 +1449,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1390,6 +1449,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, p_drop,
h_ratio,
seeds); seeds);
} }
......
...@@ -1521,10 +1521,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1521,10 +1521,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1, const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1, const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const KGridDesc_K0_N_K1& kgrad_grid_desc_k0_n_k1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3, const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3& const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1, const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const VGridDesc_O0_N_O1& vgrad_grid_desc_o0_n_o1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock& const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock, y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m, const LSEGridDesc_M& lse_grid_desc_m,
...@@ -1557,11 +1559,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1557,11 +1559,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_o0_m_o1.GetElementSpaceSize()); p_ygrad_grid, ygrad_grid_desc_o0_m_o1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize()); p_vgrad_grid, vgrad_grid_desc_o0_n_o1.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize()); p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_kgrad_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize()); p_kgrad_grid, kgrad_grid_desc_k0_n_k1.GetElementSpaceSize());
// divide block work by [N, K] // divide block work by [N, K]
const auto block_work_idx = const auto block_work_idx =
...@@ -1711,7 +1713,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1711,7 +1713,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// dV: transform input and output tensor descriptors // dV: transform input and output tensor descriptors
auto vgrad_grid_desc_nblock_nperblock_oblock_operblock = auto vgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(v_grid_desc_o0_n_o1); MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(vgrad_grid_desc_o0_n_o1);
// dK: A matrix blockwise copy // dK: A matrix blockwise copy
auto kgrad_gemm_tile_sgrad_blockwise_copy = auto kgrad_gemm_tile_sgrad_blockwise_copy =
...@@ -1740,7 +1742,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1740,7 +1742,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// dK: transform input and output tensor descriptors // dK: transform input and output tensor descriptors
auto kgrad_grid_desc_nblock_nperblock_oblock_operblock = auto kgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(k_grid_desc_k0_n_k1); MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(kgrad_grid_desc_k0_n_k1);
// //
// set up dQ Gemm (type 3 crr) // set up dQ Gemm (type 3 crr)
......
...@@ -1587,10 +1587,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1587,10 +1587,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1, const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1, const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const KGridDesc_K0_N_K1& kgrad_grid_desc_k0_n_k1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3, const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3& const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1, const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const VGridDesc_O0_N_O1& vgrad_grid_desc_o0_n_o1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock& const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock, y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m, const LSEGridDesc_M& lse_grid_desc_m,
...@@ -1623,11 +1625,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1623,11 +1625,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize()); p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize()); p_vgrad_grid, vgrad_grid_desc_o0_n_o1.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize()); p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_kgrad_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize()); p_kgrad_grid, kgrad_grid_desc_k0_n_k1.GetElementSpaceSize());
// divide block work by [N, K] // divide block work by [N, K]
const auto block_work_idx = const auto block_work_idx =
...@@ -1800,7 +1802,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1800,7 +1802,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dV: transform input and output tensor descriptors // dV: transform input and output tensor descriptors
auto vgrad_grid_desc_nblock_nperblock_oblock_operblock = auto vgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(v_grid_desc_o0_n_o1); MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(vgrad_grid_desc_o0_n_o1);
// dK: transform input and output tensor descriptors // dK: transform input and output tensor descriptors
const auto q_grid_desc_m0_k_m1 = const auto q_grid_desc_m0_k_m1 =
...@@ -1833,7 +1835,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1833,7 +1835,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dK: transform input and output tensor descriptors // dK: transform input and output tensor descriptors
auto kgrad_grid_desc_nblock_nperblock_oblock_operblock = auto kgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(k_grid_desc_k0_n_k1); MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(kgrad_grid_desc_k0_n_k1);
// //
// set up dQ Gemm (type 3 crr) // set up dQ Gemm (type 3 crr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment