Commit aea324d2 authored by letaoqin

Merge branch 'mha-train-develop' into mha-train-develop-grad-bias

parents 73611570 f04ec574
@@ -5,12 +5,12 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
 add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_example_executable(example_grouped_multihead_attention_forward_v1 grouped_multihead_attention_forward_v1.cpp)
-add_example_executable(example_batched_multihead_attention_forward_v1 batched_multihead_attention_forward_v1.cpp)
-add_example_executable(example_grouped_multihead_attention_backward_v1 grouped_multihead_attention_backward_v1.cpp)
-add_example_executable(example_batched_multihead_attention_backward_v1 batched_multihead_attention_backward_v1.cpp)
-add_example_executable(example_grouped_multihead_attention_train_v1 grouped_multihead_attention_train_v1.cpp)
-add_example_executable(example_batched_multihead_attention_train_v1 batched_multihead_attention_train_v1.cpp)
+# add_example_executable(example_grouped_multihead_attention_forward_v1 grouped_multihead_attention_forward_v1.cpp)
+# add_example_executable(example_batched_multihead_attention_forward_v1 batched_multihead_attention_forward_v1.cpp)
+# add_example_executable(example_grouped_multihead_attention_backward_v1 grouped_multihead_attention_backward_v1.cpp)
+# add_example_executable(example_batched_multihead_attention_backward_v1 batched_multihead_attention_backward_v1.cpp)
+# add_example_executable(example_grouped_multihead_attention_train_v1 grouped_multihead_attention_train_v1.cpp)
+# add_example_executable(example_batched_multihead_attention_train_v1 batched_multihead_attention_train_v1.cpp)
 add_example_executable(example_grouped_multihead_attention_forward_v2 grouped_multihead_attention_forward_v2.cpp)
 add_example_executable(example_batched_multihead_attention_forward_v2 batched_multihead_attention_forward_v2.cpp)
 add_example_executable(example_grouped_multihead_attention_backward_v2 grouped_multihead_attention_backward_v2.cpp)
......
@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 TensorLSE& lse_g_m,
 TensorP& p_drop_g_m_n,
 TensorZ& z_g_m_n,
-ZDataType p_dropout_in_16bits,
+ZDataType p_dropout_in_uint8_t,
 float rp_dropout)
 {
 // S = alpha * Q * K^T
@@ -249,7 +249,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment =
-ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // Y = P_dropout * V
@@ -324,10 +324,10 @@ int run(int argc, char* argv[])
 exit(0);
 }
 float p_dropout = 1 - p_drop;
-ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
+ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
 float rp_dropout = 1.0 / p_dropout;
 float alpha = 1.f / std::sqrt(K);
 std::cout << "do_verification: " << do_verification << std::endl;
 std::cout << "init_method: " << init_method << std::endl;
@@ -631,7 +631,7 @@ int run(int argc, char* argv[])
 lse_g_m,
 p_drop_g_m_n,
 z_g_m_n,
-p_dropout_in_16bits,
+p_dropout_in_uint8_t,
 rp_dropout);
 y_gs_ms_os.ForEach([&](auto& self, auto idx) {
 self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
@@ -691,7 +691,7 @@ int run(int argc, char* argv[])
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment = ref_dropout.MakeArgument(
-z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
......
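
The change repeated across all of these example files is the same one: the dropout keep-threshold moves from 16 bits to 8. The keep probability p_dropout = 1 - p_drop is quantized to floor(p_dropout * 255.0) and stored in the (now 8-bit) ZDataType, each element survives when its random byte is <= that threshold, and survivors are rescaled by rp_dropout = 1 / p_dropout. A minimal host-side sketch of that quantize/compare/rescale scheme (the function name and the std::mt19937 stand-in for the kernels' philox generator are illustrative, not from this repo):

#include <cmath>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical sketch of the uint8 dropout scheme used in these examples:
// keep an element iff its random byte is <= floor((1 - p_drop) * 255),
// then rescale survivors by 1 / (1 - p_drop).
std::vector<float> dropout_u8(const std::vector<float>& x, float p_drop, uint32_t seed)
{
    const float p_keep         = 1.f - p_drop;
    const uint8_t threshold_u8 = static_cast<uint8_t>(std::floor(p_keep * 255.0));
    const float rescale        = 1.f / p_keep;

    std::mt19937 gen(seed); // stand-in for the philox generator in the kernels
    std::uniform_int_distribution<int> byte(0, 255);

    std::vector<float> y(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        const uint8_t z = static_cast<uint8_t>(byte(gen)); // the stored "z" mask value
        y[i]            = (z <= threshold_u8) ? x[i] * rescale : 0.f;
    }
    return y;
}

The 8-bit threshold halves the random state the kernels must generate and store per element; the trade-off is a coarser quantization of the realized keep rate than the previous 16-bit comparison.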
@@ -218,7 +218,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 TensorLSE& lse_g_m,
 TensorP& p_drop_g_m_n,
 TensorZ& z_g_m_n,
-ZDataType p_dropout_in_16bits,
+ZDataType p_dropout_in_uint8_t,
 float rp_dropout)
 {
 // S = alpha * Q * K^T
@@ -250,7 +250,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment =
-ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // Y = P_dropout * V
@@ -325,10 +325,10 @@ int run(int argc, char* argv[])
 exit(0);
 }
 float p_dropout = 1 - p_drop;
-ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
+ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
 float rp_dropout = 1.0 / p_dropout;
 float alpha = 1.f / std::sqrt(K);
 std::cout << "do_verification: " << do_verification << std::endl;
 std::cout << "init_method: " << init_method << std::endl;
@@ -637,7 +637,7 @@ int run(int argc, char* argv[])
 lse_g_m,
 p_drop_g_m_n,
 z_g_m_n,
-p_dropout_in_16bits,
+p_dropout_in_uint8_t,
 rp_dropout);
 y_gs_ms_os.ForEach([&](auto& self, auto idx) {
 self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
@@ -697,7 +697,7 @@ int run(int argc, char* argv[])
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment = ref_dropout.MakeArgument(
-z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
......
@@ -247,7 +247,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 TensorLSE& lse_g_m,
 TensorP& p_drop_g_m_n,
 TensorZ& z_g_m_n,
-ZDataType p_dropout_in_16bits,
+ZDataType p_dropout_in_uint8_t,
 float rp_dropout)
 {
 // S = alpha * Q * K^T
@@ -279,7 +279,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment =
-ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // Y = P_dropout * V
@@ -354,10 +354,10 @@ int run(int argc, char* argv[])
 exit(0);
 }
 float p_dropout = 1 - p_drop;
-ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
+ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
 float rp_dropout = 1.0 / p_dropout;
 float alpha = 1.f / std::sqrt(K);
 std::cout << "do_verification: " << do_verification << std::endl;
 std::cout << "init_method: " << init_method << std::endl;
@@ -815,7 +815,7 @@ int run(int argc, char* argv[])
 lse_g_m,
 p_drop_g_m_n,
 z_fwd_g_m_n,
-p_dropout_in_16bits,
+p_dropout_in_uint8_t,
 rp_dropout);
 ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
@@ -858,7 +858,7 @@ int run(int argc, char* argv[])
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment = ref_dropout.MakeArgument(
-z_bwd_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+z_bwd_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
......
@@ -216,7 +216,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 TensorLSE& lse_g_m,
 TensorP& p_drop_g_m_n,
 TensorZ& z_g_m_n,
-ZDataType p_dropout_in_16bits,
+ZDataType p_dropout_in_uint8_t,
 float rp_dropout)
 {
 // S = alpha * Q * K^T
@@ -248,7 +248,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment =
-ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // Y = P_dropout * V
@@ -311,9 +311,9 @@ int run(int argc, char* argv[])
 exit(0);
 }
 float p_dropout = 1 - p_drop;
-ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
+ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
 float rp_dropout = 1.0 / p_dropout;
 auto gemm = DeviceGemmInstance{};
 auto invoker = gemm.MakeInvoker();
@@ -690,7 +690,7 @@ int run(int argc, char* argv[])
 lse_g_ms[i],
 p_drop_g_m_ns[i],
 z_g_m_ns[i],
-p_dropout_in_16bits,
+p_dropout_in_uint8_t,
 rp_dropout);
 y_tensors[i].ForEach([&](auto& self, auto idx) {
@@ -742,7 +742,7 @@ int run(int argc, char* argv[])
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment = ref_dropout.MakeArgument(
-z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
......
@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 TensorLSE& lse_g_m,
 TensorP& p_drop_g_m_n,
 TensorZ& z_g_m_n,
-ZDataType p_dropout_in_16bits,
+ZDataType p_dropout_in_uint8_t,
 float rp_dropout)
 {
 // S = alpha * Q * K^T
@@ -249,7 +249,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment =
-ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // Y = P_dropout * V
@@ -312,9 +312,9 @@ int run(int argc, char* argv[])
 exit(0);
 }
 float p_dropout = 1 - p_drop;
-ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
+ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
 float rp_dropout = 1.0 / p_dropout;
 auto gemm = DeviceGemmInstance{};
 auto invoker = gemm.MakeInvoker();
@@ -703,7 +703,7 @@ int run(int argc, char* argv[])
 lse_g_ms[i],
 p_drop_g_m_ns[i],
 z_g_m_ns[i],
-p_dropout_in_16bits,
+p_dropout_in_uint8_t,
 rp_dropout);
 y_tensors[i].ForEach([&](auto& self, auto idx) {
@@ -755,7 +755,7 @@ int run(int argc, char* argv[])
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment = ref_dropout.MakeArgument(
-z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
......
@@ -246,7 +246,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 TensorLSE& lse_g_m,
 TensorP& p_drop_g_m_n,
 TensorZ& z_g_m_n,
-ZDataType p_dropout_in_16bits,
+ZDataType p_dropout_in_uint8_t,
 float rp_dropout)
 {
 // S = alpha * Q * K^T
@@ -278,7 +278,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment =
-ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // Y = P_dropout * V
@@ -341,9 +341,9 @@ int run(int argc, char* argv[])
 exit(0);
 }
 float p_dropout = 1 - p_drop;
-ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
+ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
 float rp_dropout = 1.0 / p_dropout;
 auto gemm_fwd = DeviceGemmInstanceFWD{};
 auto invoker_fwd = gemm_fwd.MakeInvoker();
@@ -860,7 +860,7 @@ int run(int argc, char* argv[])
 lse_g_ms[i],
 p_drop_g_m_ns[i],
 z_fwd_g_m_ns[i],
-p_dropout_in_16bits,
+p_dropout_in_uint8_t,
 rp_dropout);
 int G0 = v_tensors[i].GetLengths()[0];
@@ -893,7 +893,7 @@ int run(int argc, char* argv[])
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment = ref_dropout.MakeArgument(
-z_bwd_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+z_bwd_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
......
@@ -66,10 +66,10 @@ int run(int argc, char* argv[])
 exit(0);
 }
 float p_dropout = 1 - p_drop;
-ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
+ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
 float rp_dropout = 1.0 / p_dropout;
 float alpha = 1.f / std::sqrt(K);
 std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
 std::vector<ck::index_t> a_gs_ms_ks_strides =
@@ -159,6 +159,7 @@ int run(int argc, char* argv[])
 a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
 b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
 b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
+z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
 auto a_element_op = AElementOp{};
 auto b0_element_op = B0ElementOp{};
@@ -322,7 +323,7 @@ int run(int argc, char* argv[])
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment = ref_dropout.MakeArgument(
-z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // gemm1
......
@@ -43,9 +43,9 @@ int run(int argc, char* argv[])
 exit(0);
 }
 float p_dropout = 1 - p_drop;
-uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
 float rp_dropout = 1.0 / p_dropout;
 float alpha = 1; // scaling after 1st gemm
@@ -217,6 +217,7 @@ int run(int argc, char* argv[])
 a_tensors_device[i]->ToDevice(a_gs_ms_ks.mData.data());
 b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data());
 b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data());
+z_tensors_device[i]->ToDevice(z_gs_ms_ns.mData.data());
 p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
 p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
@@ -390,7 +391,7 @@ int run(int argc, char* argv[])
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment = ref_dropout.MakeArgument(
-z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // gemm 1
......
@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 TensorLSE& lse_g_m,
 TensorP& p_drop_g_m_n,
 TensorZ& z_g_m_n,
-ZDataType p_dropout_in_16bits,
+ZDataType p_dropout_in_uint8_t,
 float rp_dropout)
 {
 // S = alpha * Q * K^T
@@ -252,7 +252,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment =
-ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // Y = P_dropout * V
@@ -327,10 +327,10 @@ int run(int argc, char* argv[])
 exit(0);
 }
 float p_dropout = 1 - p_drop;
-ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
+ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
 float rp_dropout = 1.0 / p_dropout;
 float alpha = 1.f / std::sqrt(K);
 std::cout << "do_verification: " << do_verification << std::endl;
 std::cout << "init_method: " << init_method << std::endl;
@@ -659,7 +659,7 @@ int run(int argc, char* argv[])
 lse_g_m,
 p_drop_g_m_n,
 z_g_m_n,
-p_dropout_in_16bits,
+p_dropout_in_uint8_t,
 rp_dropout);
 y_gs_ms_os.ForEach([&](auto& self, auto idx) {
 self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
@@ -719,7 +719,7 @@ int run(int argc, char* argv[])
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment = ref_dropout.MakeArgument(
-z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
......
@@ -216,7 +216,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 TensorLSE& lse_g_m,
 TensorP& p_drop_g_m_n,
 TensorZ& z_g_m_n,
-ZDataType p_dropout_in_16bits,
+ZDataType p_dropout_in_uint8_t,
 float rp_dropout)
 {
 // S = alpha * Q * K^T
@@ -251,7 +251,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment =
-ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // Y = P_dropout * V
@@ -314,9 +314,9 @@ int run(int argc, char* argv[])
 exit(0);
 }
 float p_dropout = 1 - p_drop;
-ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
+ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
 float rp_dropout = 1.0 / p_dropout;
 auto gemm = DeviceGemmInstance{};
 auto invoker = gemm.MakeInvoker();
@@ -730,7 +730,7 @@ int run(int argc, char* argv[])
 lse_g_ms[i],
 p_drop_g_m_ns[i],
 z_g_m_ns[i],
-p_dropout_in_16bits,
+p_dropout_in_uint8_t,
 rp_dropout);
 y_tensors[i].ForEach([&](auto& self, auto idx) {
@@ -784,7 +784,7 @@ int run(int argc, char* argv[])
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment = ref_dropout.MakeArgument(
-z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
 sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
......
@@ -66,10 +66,10 @@ int run(int argc, char* argv[])
 exit(0);
 }
 float p_dropout = 1 - p_drop;
-ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
+ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
 float rp_dropout = 1.0 / p_dropout;
 float alpha = 1.f / std::sqrt(K);
 std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
 std::vector<ck::index_t> a_gs_ms_ks_strides =
@@ -172,6 +172,7 @@ int run(int argc, char* argv[])
 b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
 b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
 d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
+z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
 auto a_element_op = AElementOp{};
 auto b0_element_op = B0ElementOp{};
@@ -243,6 +244,18 @@ int run(int argc, char* argv[])
 if(do_verification)
 {
+// data objects for hipGraph verification
+hipGraph_t graph;
+hipGraphExec_t g_instance;
+hipStream_t stream;
+std::cout << "verification with hipGraph capturing and replaying ... " << std::endl;
+HIP_CHECK_ERROR(hipStreamCreate(&stream));
+HIP_CHECK_ERROR(hipGraphCreate(&graph, 0));
+HIP_CHECK_ERROR(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
 // run for storing z tensor
 argument =
 gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
@@ -276,9 +289,19 @@ int run(int argc, char* argv[])
 p_drop, // dropout ratio
 {seed, offset}); // dropout random seed and offset, offset should be
 // at least the number of elements on a thread
-c_device_buf.SetZero();
-lse_device_buf.SetZero();
-invoker.Run(argument, StreamConfig{nullptr, false});
+HIP_CHECK_ERROR(hipMemsetAsync(
+c_device_buf.GetDeviceBuffer(), 0, c_device_buf.GetBufferSize(), stream));
+HIP_CHECK_ERROR(hipMemsetAsync(
+lse_device_buf.GetDeviceBuffer(), 0, lse_device_buf.GetBufferSize(), stream));
+invoker.Run(argument, StreamConfig{stream, false});
+HIP_CHECK_ERROR(hipStreamEndCapture(stream, &graph));
+HIP_CHECK_ERROR(hipGraphInstantiate(&g_instance, graph, nullptr, nullptr, 0));
+HIP_CHECK_ERROR(hipGraphLaunch(g_instance, stream));
+HIP_CHECK_ERROR(hipStreamSynchronize(stream));
 c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
 z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
@@ -322,7 +345,9 @@ int run(int argc, char* argv[])
 ref_gemm0_invoker.Run(ref_gemm0_argument);
 // bias
-acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
+acc0_g_m_n.ForEach([&](auto& self, auto idx) {
+self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx));
+});
 // masking
 const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
 acc0_g_m_n.ForEach([&](auto& self, auto idx) {
@@ -342,7 +367,7 @@ int run(int argc, char* argv[])
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment = ref_dropout.MakeArgument(
-z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // gemm1
......
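
The verification path above no longer runs the zero-fills and the kernel launch directly on the null stream; it records them into a hipGraph by stream capture, instantiates the graph once, and replays it. A stripped-down sketch of that capture/instantiate/launch sequence (the CHECK macro and the replay_memset wrapper are illustrative; the example itself uses the HIP_CHECK_ERROR macro added later in this commit):

#include <hip/hip_runtime.h>
#include <stdexcept>

#define CHECK(call)                                      \
    do                                                   \
    {                                                    \
        if((call) != hipSuccess)                         \
            throw std::runtime_error("HIP call failed"); \
    } while(0)

// Record async work on a stream into a graph, then replay it.
void replay_memset(void* buf, size_t n)
{
    hipStream_t stream;
    hipGraph_t graph;
    hipGraphExec_t exec;

    CHECK(hipStreamCreate(&stream));
    CHECK(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
    CHECK(hipMemsetAsync(buf, 0, n, stream)); // recorded, not yet executed
    CHECK(hipStreamEndCapture(stream, &graph));

    CHECK(hipGraphInstantiate(&exec, graph, nullptr, nullptr, 0));
    CHECK(hipGraphLaunch(exec, stream)); // the memset actually runs here
    CHECK(hipStreamSynchronize(stream));

    CHECK(hipGraphExecDestroy(exec));
    CHECK(hipGraphDestroy(graph));
    CHECK(hipStreamDestroy(stream));
}

While work is being captured nothing executes, which is why the invoker above must take the capturing stream (StreamConfig{stream, false}) rather than the default one: a launch on any other stream would not be recorded into the graph.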
@@ -43,9 +43,9 @@ int run(int argc, char* argv[])
 exit(0);
 }
 float p_dropout = 1 - p_drop;
-uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
 float rp_dropout = 1.0 / p_dropout;
 float alpha = 1; // scaling after 1st gemm
@@ -149,8 +149,8 @@ int run(int argc, char* argv[])
 lse_gs_ms_strides,
 d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
 d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
 {}, // acc1_biases_gs_ms_os_lengths
 {}}); // acc1_biases_gs_ms_os_strides
 // C_m_o = A_m_k * B0_k_n * B1_n_o
 Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
@@ -163,10 +163,11 @@ int run(int argc, char* argv[])
 int Batch = G0 * G1;
 flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
-num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
-sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
-Batch;
+num_byte +=
+(sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
+sizeof(CDataType) * M * O +
+sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
+Batch;
 if(i < 4)
 {
@@ -237,6 +238,7 @@ int run(int argc, char* argv[])
 b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data());
 b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data());
 d_tensors_device[i]->ToDevice(d_gs_ms_ns.mData.data());
+z_tensors_device[i]->ToDevice(z_gs_ms_ns.mData.data());
 p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
 p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
@@ -301,6 +303,18 @@ int run(int argc, char* argv[])
 bool pass = true;
 if(do_verification)
 {
+// data objects for hipGraph verification
+hipGraph_t graph;
+hipGraphExec_t g_instance;
+hipStream_t stream;
+std::cout << "verification with hipGraph capturing and replaying ... " << std::endl;
+HIP_CHECK_ERROR(hipStreamCreate(&stream));
+HIP_CHECK_ERROR(hipGraphCreate(&graph, 0));
+HIP_CHECK_ERROR(hipStreamBeginCapture(stream, hipStreamCaptureModeRelaxed));
 argument =
 gemm.MakeArgument(p_a,
 p_b0,
@@ -324,7 +338,16 @@ int run(int argc, char* argv[])
 gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
-invoker.Run(argument, StreamConfig{nullptr, false});
+invoker.Run(argument, StreamConfig{stream, false});
+HIP_CHECK_ERROR(hipStreamEndCapture(stream, &graph));
+HIP_CHECK_ERROR(hipGraphInstantiate(&g_instance, graph, nullptr, nullptr, 0));
+HIP_CHECK_ERROR(hipGraphDebugDotPrint(graph, "grouped_fwd_debug.dot", 0x007f));
+HIP_CHECK_ERROR(hipGraphLaunch(g_instance, stream));
+HIP_CHECK_ERROR(hipStreamSynchronize(stream));
 for(std::size_t i = 0; i < group_count; i++)
 {
@@ -396,7 +419,9 @@ int run(int argc, char* argv[])
 ref_gemm0_invoker.Run(ref_gemm0_argument);
 // bias
-acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
+acc0_g_m_n.ForEach([&](auto& self, auto idx) {
+self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx));
+});
 // masking
 const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
 acc0_g_m_n.ForEach([&](auto& self, auto idx) {
@@ -419,7 +444,7 @@ int run(int argc, char* argv[])
 auto ref_dropout = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment = ref_dropout.MakeArgument(
-z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // gemm 1
......
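
For reference, the reflowed num_byte accounting above counts one read of each GEMM operand (A: M*K, B0: K*N, B1: N*O), one write of the output (C: M*O), and one read of the M*N bias when Acc0BiasDataType is non-void, all scaled by the batch count; flop counts 2*M*N*K for the first GEMM and 2*M*N*O for the second. A worked instance with made-up fp16 sizes (all numbers illustrative):

#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t M = 512, N = 512, K = 64, O = 64, Batch = 16;
    // fp16 operands and bias -> 2 bytes each
    const std::size_t eA = 2, eB0 = 2, eB1 = 2, eC = 2, eBias = 2;

    const std::size_t flop = (M * N * K * 2 + M * N * O * 2) * Batch;
    const std::size_t num_byte =
        (eA * M * K + eB0 * K * N + eB1 * N * O + eC * M * O + eBias * M * N) * Batch;

    std::printf("flop = %zu, bytes = %zu\n", flop, num_byte); // ~1.07 GFLOP, ~12 MB
    return 0;
}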
@@ -15,3 +15,16 @@ inline void hip_check_error(hipError_t x)
 throw std::runtime_error(ss.str());
 }
 }
+#define HIP_CHECK_ERROR(flag)                                                  \
+do                                                                             \
+{                                                                              \
+    hipError_t _tmpVal;                                                        \
+    if((_tmpVal = flag) != hipSuccess)                                         \
+    {                                                                          \
+        std::ostringstream ostr;                                               \
+        ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \
+             << hipGetErrorString(_tmpVal);                                    \
+        throw std::runtime_error(ostr.str());                                  \
+    }                                                                          \
+} while(0)
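
The HIP_CHECK_ERROR macro added above stashes the call's hipError_t in a temporary and throws a std::runtime_error carrying the file, line, and hipGetErrorString text on failure; the do { ... } while(0) wrapper makes it expand as a single statement, so it composes safely inside if/else. Minimal illustrative use (assumes the header above is included):

#include <hip/hip_runtime.h>
#include <sstream>
#include <stdexcept>
// ... HIP_CHECK_ERROR defined as above ...

int main()
{
    hipStream_t stream;
    HIP_CHECK_ERROR(hipStreamCreate(&stream)); // throws with file/line on failure
    HIP_CHECK_ERROR(hipStreamSynchronize(stream));
    HIP_CHECK_ERROR(hipStreamDestroy(stream));
    return 0;
}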
@@ -16,111 +16,111 @@ struct BlockwiseDropout
 static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
 static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
-template <typename CThreadBuffer, bool using_sign_bit = false>
-__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph)
-{
-auto execute_dropout = [&](bool keep, DataType val) {
-if constexpr(using_sign_bit)
-return keep ? val : -val;
-else
-return keep ? val * p_dropout_rescale : float(0);
-};
-constexpr int tmp_size = MRepeat * KRepeat;
-int philox_calls = tmp_size / 8;
-ushort tmp[tmp_size];
-for(int i = 0; i < philox_calls; i++)
-{
-ph.get_random_8x16((tmp + i * 8));
-}
-block_sync_lds();
-int tmp_index = 0;
-static_for<0, MRepeat, 1>{}([&](auto iM) {
-static_for<0, KRepeat, 1>{}([&](auto iK) {
-auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
-in_thread_buf(offset) =
-execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
-tmp_index = tmp_index + 1;
-});
-});
-}
-template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
-__host__ __device__ void
-ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
-{
-auto execute_dropout = [&](bool keep, DataType val) {
-if constexpr(using_sign_bit)
-return keep ? val : -val;
-else
-return keep ? val * p_dropout_rescale : float(0);
-};
-constexpr int tmp_size = MRepeat * KRepeat;
-int philox_calls = tmp_size / 8;
-ushort tmp[tmp_size];
-for(int i = 0; i < philox_calls; i++)
-{
-ph.get_random_8x16((tmp + i * 8));
-}
-block_sync_lds();
-int tmp_index = 0;
-static_for<0, MRepeat, 1>{}([&](auto iM) {
-static_for<0, KRepeat, 1>{}([&](auto iK) {
-auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
-in_thread_buf(offset) =
-execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
-z_thread_buf(offset) = tmp[tmp_index];
-tmp_index = tmp_index + 1;
-});
-});
-}
-template <typename CThreadBuffer,
-typename ZThreadBuffer,
-bool using_sign_bit,
-typename N0,
-typename Offset>
-__host__ __device__ void
-ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
-{
-auto execute_dropout = [&](bool keep, DataType val) {
-if constexpr(using_sign_bit)
-return keep ? val : -val;
-else
-return keep ? val * p_dropout_rescale : float(0);
-};
-constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
-int philox_calls = tmp_size / 8;
-ushort tmp[tmp_size];
-for(int i = 0; i < philox_calls; i++)
-{
-ph.get_random_8x16((tmp + i * 8));
-}
-block_sync_lds();
-constexpr auto iOffset = Number<tmp_size>{} * Offset{};
-static_for<0, tmp_size, 1>{}([&](auto i) {
-in_thread_buf(i + iOffset) =
-execute_dropout(tmp[i.value] <= p_dropout_16bits, in_thread_buf(i + iOffset));
-z_thread_buf(i) = tmp[i.value];
-});
-}
+// template <typename CThreadBuffer, bool using_sign_bit = false>
+// __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph)
+// {
+// auto execute_dropout = [&](bool keep, DataType val) {
+// if constexpr(using_sign_bit)
+// return keep ? val : -val;
+// else
+// return keep ? val * p_dropout_rescale : float(0);
+// };
+// constexpr int tmp_size = MRepeat * KRepeat;
+// int philox_calls = tmp_size / 8;
+// ushort tmp[tmp_size];
+// for(int i = 0; i < philox_calls; i++)
+// {
+// ph.get_random_8x16((tmp + i * 8));
+// }
+// block_sync_lds();
+// int tmp_index = 0;
+// static_for<0, MRepeat, 1>{}([&](auto iM) {
+// static_for<0, KRepeat, 1>{}([&](auto iK) {
+// auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM,
+// iK))>{}; in_thread_buf(offset) =
+// execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
+// tmp_index = tmp_index + 1;
+// });
+// });
+// }
+// template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
+// __host__ __device__ void
+// ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
+// {
+// auto execute_dropout = [&](bool keep, DataType val) {
+// if constexpr(using_sign_bit)
+// return keep ? val : -val;
+// else
+// return keep ? val * p_dropout_rescale : float(0);
+// };
+// constexpr int tmp_size = MRepeat * KRepeat;
+// int philox_calls = tmp_size / 8;
+// ushort tmp[tmp_size];
+// for(int i = 0; i < philox_calls; i++)
+// {
+// ph.get_random_8x16((tmp + i * 8));
+// }
+// block_sync_lds();
+// int tmp_index = 0;
+// static_for<0, MRepeat, 1>{}([&](auto iM) {
+// static_for<0, KRepeat, 1>{}([&](auto iK) {
+// auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM,
+// iK))>{}; in_thread_buf(offset) =
+// execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
+// z_thread_buf(offset) = tmp[tmp_index];
+// tmp_index = tmp_index + 1;
+// });
+// });
+// }
+// template <typename CThreadBuffer,
+// typename ZThreadBuffer,
+// bool using_sign_bit,
+// typename N0,
+// typename Offset>
+// __host__ __device__ void
+// ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
+// {
+// auto execute_dropout = [&](bool keep, DataType val) {
+// if constexpr(using_sign_bit)
+// return keep ? val : -val;
+// else
+// return keep ? val * p_dropout_rescale : float(0);
+// };
+// constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
+// int philox_calls = tmp_size / 8;
+// ushort tmp[tmp_size];
+// for(int i = 0; i < philox_calls; i++)
+// {
+// ph.get_random_8x16((tmp + i * 8));
+// }
+// block_sync_lds();
+// constexpr auto iOffset = Number<tmp_size>{} * Offset{};
+// static_for<0, tmp_size, 1>{}([&](auto i) {
+// in_thread_buf(i + iOffset) =
+// execute_dropout(tmp[i.value] <= p_dropout_uint8_t, in_thread_buf(i + iOffset));
+// z_thread_buf(i) = tmp[i.value];
+// });
+// }
 template <typename CThreadBuffer, typename Offset, bool using_sign_bit = false>
 __host__ __device__ void ApplyDropoutAttnBwd(CThreadBuffer& in_thread_buf,
@@ -138,12 +138,12 @@ struct BlockwiseDropout
 constexpr int tmp_size = MRepeat * KRepeat;
-int philox_calls = tmp_size / 8;
-ushort tmp[tmp_size];
+int philox_calls = tmp_size / 16;
+uint8_t tmp[tmp_size];
 for(int i = 0; i < philox_calls; i++)
 {
-ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
+ph.get_random_16x8((tmp + i * 16), element_global_1d_id + i * Offset{} * MRaw);
 }
 block_sync_lds();
@@ -153,7 +153,7 @@ struct BlockwiseDropout
 static_for<0, KRepeat, 1>{}([&](auto iK) {
 auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
 in_thread_buf(offset) =
-execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
+execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
 tmp_index = tmp_index + 1;
 });
 });
@@ -179,12 +179,12 @@ struct BlockwiseDropout
 constexpr int tmp_size = MRepeat * KRepeat;
-int philox_calls = tmp_size / 8;
-ushort tmp[tmp_size];
+int philox_calls = tmp_size / 16;
+uint8_t tmp[tmp_size];
 for(int i = 0; i < philox_calls; i++)
 {
-ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
+ph.get_random_16x8((tmp + i * 16), element_global_1d_id + i * Offset{} * MRaw);
 }
 block_sync_lds();
@@ -194,7 +194,7 @@ struct BlockwiseDropout
 static_for<0, KRepeat, 1>{}([&](auto iK) {
 auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
 in_thread_buf(offset) =
-execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
+execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
 z_thread_buf(offset) = tmp[tmp_index];
 tmp_index = tmp_index + 1;
 });
@@ -213,7 +213,7 @@ struct BlockwiseDropout
 constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
 static_for<0, tmp_size, 1>{}([&](auto i) {
 in_thread_buf(i + Offset{}) =
-execute_dropout(z_thread_buf(i) <= p_dropout_16bits, in_thread_buf(i + Offset{}));
+execute_dropout(z_thread_buf(i) <= p_dropout_uint8_t, in_thread_buf(i + Offset{}));
 });
 }
@@ -225,18 +225,18 @@ struct BlockwiseDropout
 {
 constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
-int philox_calls = tmp_size / 8;
-ushort tmp[tmp_size];
+int philox_calls = tmp_size / 16;
+uint8_t tmp[tmp_size];
 for(int i = 0; i < philox_calls; i++)
 {
-ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{});
+ph.get_random_16x8((tmp + i * 16), element_global_1d_id + i * Offset{});
 }
 static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; });
 }
-ushort p_dropout_16bits;
+uint8_t p_dropout_uint8_t;
 DataType p_dropout_rescale;
 };
......
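
Alongside the threshold type change, the live ApplyDropoutAttnBwd paths above switch the generator call from get_random_8x16 (eight 16-bit values per philox call) to get_random_16x8 (sixteen bytes per call), so filling the same MRepeat * KRepeat scratch array takes tmp_size / 16 calls instead of tmp_size / 8, and the keep test compares bytes against p_dropout_uint8_t. A host-side sketch of just that bookkeeping, with a hypothetical stand-in generator (FakePhilox is not the ck::philox API):

#include <cstdint>

struct FakePhilox // illustrative stand-in, not ck::philox
{
    uint32_t state = 0x9E3779B9u;
    void get_random_16x8(uint8_t* out, uint32_t /*subsequence*/ = 0)
    {
        for(int i = 0; i < 16; ++i)
        {
            state  = state * 1664525u + 1013904223u; // LCG, for illustration only
            out[i] = static_cast<uint8_t>(state >> 24);
        }
    }
};

template <int MRepeat, int KRepeat>
void fill_dropout_tile(float* buf, uint8_t keep_threshold_u8, float rescale, FakePhilox& ph)
{
    constexpr int tmp_size = MRepeat * KRepeat;
    static_assert(tmp_size % 16 == 0, "tile must fill whole 16-byte generator calls");
    uint8_t tmp[tmp_size];
    for(int i = 0; i < tmp_size / 16; ++i) // was tmp_size / 8 calls of get_random_8x16
        ph.get_random_16x8(tmp + i * 16);
    for(int i = 0; i < tmp_size; ++i)
        buf[i] = (tmp[i] <= keep_threshold_u8) ? buf[i] * rescale : 0.f;
}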
@@ -69,7 +69,6 @@ __global__ void
 raw_n_padded);
 #else
 ignore = p_z_grid;
-ignore = a_grid_desc_ak0_m_ak1;
 ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
 ignore = block_2_ctile_map;
 ignore = batch_count;
......
...@@ -40,7 +40,7 @@ template <typename GridwiseGemm, ...@@ -40,7 +40,7 @@ template <typename GridwiseGemm,
typename D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6, typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename LSEGridDescriptor_M, typename LSEGridDescriptor_M,
typename Block2CTileMap, typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch, typename ComputeBasePtrOfStridedBatch,
...@@ -73,15 +73,15 @@ __global__ void ...@@ -73,15 +73,15 @@ __global__ void
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const LSEGridDescriptor_M lse_grid_desc_m, const LSEGridDescriptor_M lse_grid_desc_m,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
const index_t mblock, const index_t mblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
const ushort p_dropout_in_16bits, const uint8_t p_dropout_in_uint8_t,
const GemmAccDataType p_dropout_rescale, const GemmAccDataType p_dropout_rescale,
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset, const unsigned long long offset,
...@@ -145,11 +145,11 @@ __global__ void ...@@ -145,11 +145,11 @@ __global__ void
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
p_dropout_in_16bits, p_dropout_in_uint8_t,
p_dropout_rescale, p_dropout_rescale,
ph, ph,
z_random_matrix_offset, z_random_matrix_offset,
...@@ -178,11 +178,11 @@ __global__ void ...@@ -178,11 +178,11 @@ __global__ void
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
p_dropout_in_16bits, p_dropout_in_uint8_t,
p_dropout_rescale, p_dropout_rescale,
ph, ph,
z_random_matrix_offset, z_random_matrix_offset,
...@@ -207,14 +207,14 @@ __global__ void ...@@ -207,14 +207,14 @@ __global__ void
ignore = d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; ignore = d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = b1_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6; ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = lse_grid_desc_m; ignore = lse_grid_desc_m;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
ignore = batch_count; ignore = batch_count;
ignore = mblock; ignore = mblock;
ignore = compute_base_ptr_of_batch; ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask; ignore = c0_matrix_mask;
ignore = p_dropout_in_16bits; ignore = p_dropout_in_uint8_t;
ignore = p_dropout_rescale; ignore = p_dropout_rescale;
ignore = seed; ignore = seed;
ignore = offset; ignore = offset;
@@ -695,18 +695,17 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
 }
 }
 is_dropout_ = p_dropout > 0.0; //
 p_dropout_ = 1.f - p_dropout;
-p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
+p_dropout_in_uint8_t_ = uint8_t(std::floor(p_dropout_ * 255.0));
 p_dropout_ = 1.f / p_dropout_;
 p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);
 seed_ = std::get<0>(seeds);
 offset_ = std::get<1>(seeds);
-z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_ =
-    GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
-        z_grid_desc_m_n_);
+z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
+    GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
 m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
 n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
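Note: the hunk above narrows the dropout threshold from 16 bits to 8. The keep probability 1 - p_drop is quantized to a byte via floor(p_keep * 255), a random byte is later compared against that threshold inside BlockwiseDropout, and surviving elements are rescaled by 1/p_keep so the expected value is unchanged. A minimal host-side sketch of that arithmetic, using illustrative names not taken from this patch:

    // dropout_threshold_sketch.cpp -- standalone illustration, not part of CK
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const float p_drop = 0.2f;         // dropout probability
        const float p_keep = 1.f - p_drop; // survival probability

        // Quantize to a byte threshold, mirroring uint8_t(std::floor(p_keep * 255.0)).
        const uint8_t threshold = static_cast<uint8_t>(std::floor(p_keep * 255.0));

        // Survivors are scaled by 1/p_keep so E[output] matches the no-dropout case.
        const float rescale = 1.f / p_keep;

        std::printf("threshold = %u, rescale = %.4f\n", threshold, rescale);
        return 0;
    }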
@@ -779,8 +778,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
     c_grid_desc_mblock_mperblock_nblock_nperblock_;
-typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
-    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
+typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
 // block-to-c-tile map
 typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
@@ -806,7 +805,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
 ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
 float p_dropout_;
-ushort p_dropout_in_16bits_;
+uint8_t p_dropout_in_uint8_t_;
 GemmAccDataType p_dropout_rescale_;
 unsigned long long seed_;
 unsigned long long offset_;
@@ -864,7 +863,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
 typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
 DeviceOp::B1GridDesc_BK0_N_BK1,
 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
+typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
 DeviceOp::LSEGridDesc_M,
 typename GridwiseGemm::DefaultBlock2CTileMap,
 ComputeBasePtrOfStridedBatch,
@@ -897,14 +896,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
 arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
 arg.b1_grid_desc_bk0_n_bk1_,
 arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
+arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
 arg.lse_grid_desc_m_,
 arg.block_2_ctile_map_,
 arg.batch_count_,
 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
 arg.compute_base_ptr_of_batch_,
 arg.c0_matrix_mask_,
-arg.p_dropout_in_16bits_,
+arg.p_dropout_in_uint8_t_,
 arg.p_dropout_rescale_,
 arg.seed_,
 arg.offset_,
...
@@ -5,6 +5,7 @@
 #include <iostream>
 #include <sstream>
+#include <cstring>
 #include "ck/utility/common_header.hpp"
 #include "ck/utility/philox_rand.hpp"
@@ -48,7 +49,7 @@ __global__ void
 const AccElementwiseOperation acc_element_op,
 const B1ElementwiseOperation b1_element_op,
 const CElementwiseOperation c_element_op,
-const ushort p_dropout_in_16bits,
+const uint8_t p_dropout_in_uint8_t,
 const GemmAccDataType p_dropout_rescale,
 const unsigned long long seed,
 const unsigned long long offset)
@@ -140,11 +141,11 @@ __global__ void
 arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
 arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
 arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
+arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
 arg_ptr[group_id].lse_grid_desc_m_,
 arg_ptr[group_id].block_2_ctile_map_,
 arg_ptr[group_id].c0_matrix_mask_,
-p_dropout_in_16bits,
+p_dropout_in_uint8_t,
 p_dropout_rescale,
 ph,
 arg_ptr[group_id].z_random_matrix_offset_ +
@@ -178,11 +179,11 @@ __global__ void
 arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
 arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
 arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
+arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
 arg_ptr[group_id].lse_grid_desc_m_,
 arg_ptr[group_id].block_2_ctile_map_,
 arg_ptr[group_id].c0_matrix_mask_,
-p_dropout_in_16bits,
+p_dropout_in_uint8_t,
 p_dropout_rescale,
 ph,
 arg_ptr[group_id].z_random_matrix_offset_ +
@@ -198,7 +199,7 @@ __global__ void
 ignore = acc_element_op;
 ignore = b1_element_op;
 ignore = c_element_op;
-ignore = p_dropout_in_16bits;
+ignore = p_dropout_in_uint8_t;
 ignore = p_dropout_rescale;
 ignore = seed;
 ignore = offset;
@@ -620,8 +621,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
 B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
     c_grid_desc_mblock_mperblock_nblock_nperblock_;
-typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
-    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
+typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
 ZGridDesc_M_N z_grid_desc_m_n_;
 LSEGridDesc_M lse_grid_desc_m_;
@@ -774,8 +775,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
     c_grid_desc_m_n);
-const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6 =
-    GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
+const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+    GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
     z_grid_desc_m_n);
 const index_t BlockStart = grid_size_;
@@ -819,7 +820,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
 d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
 b1_grid_desc_bk0_n_bk1,
 c_grid_desc_mblock_mperblock_nblock_nperblock,
-z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
+z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
 z_grid_desc_m_n,
 lse_grid_desc_m,
 block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
@@ -857,11 +858,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
 d0_n_length_stride});
 }
 use_dropout_ = p_dropout > 0.0; //
 p_dropout_ = 1.f - p_dropout;
-p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
+p_dropout_in_uint8_t_ = uint8_t(std::floor(p_dropout_ * 255.0));
 p_dropout_ = 1.f / p_dropout_;
 p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);
 seed_ = std::get<0>(seeds);
 offset_ = std::get<1>(seeds);
@@ -880,7 +881,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
 CElementwiseOperation c_element_op_;
 float p_dropout_;
-ushort p_dropout_in_16bits_;
+uint8_t p_dropout_in_uint8_t_;
 unsigned long long seed_;
 unsigned long long offset_;
 GemmAccDataType p_dropout_rescale_;
@@ -912,10 +913,34 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
 some_has_main_k_block_loop |= y;
 }
-hipGetErrorString(hipMemcpy(arg.p_workspace_,
-                            arg.group_kernel_args_.data(),
-                            arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
-                            hipMemcpyHostToDevice));
+hipStreamCaptureStatus status = hipStreamCaptureStatusNone;
+
+HIP_CHECK_ERROR(hipStreamIsCapturing(stream_config.stream_id_, &status));
+
+if(status == hipStreamCaptureStatusActive)
+{
+    size_t copy_size = arg.group_kernel_args_.size() * sizeof(GroupKernelArg);
+
+    // TODO: when to release this memory buffer?
+    char* persistent_ptr = new char[copy_size];
+
+    (void)std::memcpy(persistent_ptr, arg.group_kernel_args_.data(), copy_size);
+
+    HIP_CHECK_ERROR(hipMemcpyAsync(arg.p_workspace_,
+                                   persistent_ptr,
+                                   copy_size,
+                                   hipMemcpyHostToDevice,
+                                   stream_config.stream_id_));
+}
+else
+{
+    HIP_CHECK_ERROR(
+        hipMemcpyAsync(arg.p_workspace_,
+                       arg.group_kernel_args_.data(),
+                       arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
+                       hipMemcpyHostToDevice,
+                       stream_config.stream_id_));
+}
 float ave_time = 0;
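Note: the synchronous hipMemcpy that this hunk replaces is not allowed while the stream is being captured into a HIP graph, and even an async copy is only safe if its host source outlives the capture, since graph replay re-reads the source pointer. The new branch therefore checks hipStreamIsCapturing and, under active capture, stages the arguments in a heap buffer before enqueuing hipMemcpyAsync. A reduced sketch of the same pattern, with an illustrative helper name and the patch's deliberate leak kept for brevity:

    #include <cstring>
    #include <hip/hip_runtime.h>

    // Copy host-side kernel arguments to device memory in a way that remains
    // valid when the stream is being captured into a HIP graph. Sketch only.
    inline hipError_t copy_args_capture_safe(
        void* dst, const void* src, size_t bytes, hipStream_t stream)
    {
        hipStreamCaptureStatus status = hipStreamCaptureStatusNone;

        hipError_t err = hipStreamIsCapturing(stream, &status);
        if(err != hipSuccess)
            return err;

        const void* stable_src = src;

        if(status == hipStreamCaptureStatusActive)
        {
            // Replay re-reads this pointer, so the staging buffer must outlive
            // the capture. (Deliberately leaked, mirroring the TODO above.)
            char* persistent = new char[bytes];
            std::memcpy(persistent, src, bytes);
            stable_src = persistent;
        }

        return hipMemcpyAsync(dst, stable_src, bytes, hipMemcpyHostToDevice, stream);
    }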
@@ -949,7 +974,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
 arg.acc_element_op_,
 arg.b1_element_op_,
 arg.c_element_op_,
-arg.p_dropout_in_16bits_,
+arg.p_dropout_in_uint8_t_,
 arg.p_dropout_rescale_,
 arg.seed_,
 arg.offset_);
...
@@ -57,8 +57,8 @@ struct GridwiseBatchedDropout
 static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
 static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-// get_random_8x16() generates 8 random numbers each time
-static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+// get_random_16x8() generates 16 random numbers each time
+static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
 using ThisThreadBlock = ThisThreadBlock<BlockSize>;
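Note: the tile constant tracks the Philox generator change. Each get_random_16x8() call now yields 16 one-byte values per thread rather than 8 sixteen-bit values, so with mfma.num_input_blks == 2 the per-iteration dropout tile doubles from 16 to 32. A plain-integer sketch of the arithmetic, with the CK Number<>/MfmaSelector machinery replaced by assumed constants:

    #include <cstddef>

    constexpr std::size_t DropoutNThread = 2;  // mfma.num_input_blks, per the comment above
    constexpr std::size_t RandomsPerCall = 16; // bytes produced by one get_random_16x8() call
    constexpr std::size_t DropoutTile    = DropoutNThread * RandomsPerCall;

    static_assert(DropoutTile == 32, "matches the '// 32' annotation in the patch");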
@@ -241,7 +241,7 @@ struct GridwiseBatchedDropout
 // only used for providing ApplyDropoutAttnBwdSaveZ
 auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-    static_cast<unsigned short>(0.8f * 65535.f), static_cast<FloatGemmAcc>(1.0f / 0.8f)};
+    static_cast<unsigned short>(0.8f * 255.f), static_cast<FloatGemmAcc>(1.0f / 0.8f)};
 //
 // z vgpr copy to global
@@ -260,7 +260,7 @@ struct GridwiseBatchedDropout
 n2)); // NPerXdl
 StaticBuffer<AddressSpaceEnum::Vgpr,
-    ushort,
+    uint8_t,
     z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
     true>
     z_tensor_buffer;
@@ -273,7 +273,7 @@ struct GridwiseBatchedDropout
 const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
 auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-    ushort,
+    uint8_t,
     ZDataType,
     decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
     decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
...
@@ -99,8 +99,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
 static constexpr auto I5 = Number<5>{};
 static constexpr auto I6 = Number<6>{};
 static constexpr auto I7 = Number<7>{};
-static constexpr auto I8 = Number<8>{};
-static constexpr auto I9 = Number<9>{};
 static constexpr auto WaveSize = 64;
@@ -120,8 +118,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
 static constexpr auto V_K0 = KPerBlock / V_K1 / V_K2;
 static constexpr auto V_N1 = NXdlPerWave;
 static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-// get_random_8x16() generates 8 random numbers each time
-static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+// get_random_16x8() generates 16 random numbers each time
+static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
 using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1450,8 +1448,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
 {
 const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
 const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
-const ushort p_dropout_in_16bits =
-    __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
+const uint8_t p_dropout_in_uint8_t =
+    __builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
 const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
     rp_dropout);
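Note: __builtin_amdgcn_readfirstlane broadcasts lane 0's value to every lane of the 64-wide wavefront, so the quantized threshold ends up in a scalar register and all lanes test dropout against the same byte. A device-side sketch of the conversion under that assumption (the helper name is illustrative):

    #include <cstdint>
    #include <hip/hip_runtime.h>

    // Make the byte-sized dropout threshold wavefront-uniform.
    __device__ uint8_t make_uniform_threshold(float p_keep)
    {
        const int t = static_cast<int>(floorf(p_keep * 255.0f));
        // readfirstlane takes and returns a 32-bit int; broadcast lane 0.
        return static_cast<uint8_t>(__builtin_amdgcn_readfirstlane(t));
    }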
@@ -1769,7 +1767,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
 decltype(thread_slice_desc_m_n)>{};
 auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-    p_dropout_in_16bits, rp_dropout};
+    p_dropout_in_uint8_t, rp_dropout};
 auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
     MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
@@ -1838,7 +1836,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
 n2)); // NPerXdl
 StaticBuffer<AddressSpaceEnum::Vgpr,
-    ushort,
+    uint8_t,
     z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
     true>
     z_tensor_buffer;
@@ -1848,7 +1846,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
 p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
 auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-    ushort,
+    uint8_t,
     ZDataType,
     decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
     decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
...