Commit 0353c29e authored by danyao12's avatar danyao12
Browse files

uint8 dropout

parent b7b7e153
......@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ZDataType p_dropout_in_16bits,
ZDataType p_dropout_in_uint8_t,
float rp_dropout)
{
// S = alpha * Q * K^T
......@@ -249,7 +249,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V
......@@ -325,7 +325,7 @@ int run(int argc, char* argv[])
}
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
......@@ -627,7 +627,7 @@ int run(int argc, char* argv[])
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
p_dropout_in_16bits,
p_dropout_in_uint8_t,
rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
......@@ -687,7 +687,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
......
......@@ -218,7 +218,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ZDataType p_dropout_in_16bits,
ZDataType p_dropout_in_uint8_t,
float rp_dropout)
{
// S = alpha * Q * K^T
......@@ -250,7 +250,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V
......@@ -326,7 +326,7 @@ int run(int argc, char* argv[])
}
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
......@@ -633,7 +633,7 @@ int run(int argc, char* argv[])
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
p_dropout_in_16bits,
p_dropout_in_uint8_t,
rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
......@@ -693,7 +693,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
......
......@@ -247,7 +247,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ZDataType p_dropout_in_16bits,
ZDataType p_dropout_in_uint8_t,
float rp_dropout)
{
// S = alpha * Q * K^T
......@@ -279,7 +279,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V
......@@ -355,7 +355,7 @@ int run(int argc, char* argv[])
}
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
......@@ -811,7 +811,7 @@ int run(int argc, char* argv[])
lse_g_m,
p_drop_g_m_n,
z_fwd_g_m_n,
p_dropout_in_16bits,
p_dropout_in_uint8_t,
rp_dropout);
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
......@@ -854,7 +854,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
z_bwd_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
z_bwd_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
......
......@@ -216,7 +216,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ZDataType p_dropout_in_16bits,
ZDataType p_dropout_in_uint8_t,
float rp_dropout)
{
// S = alpha * Q * K^T
......@@ -248,7 +248,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V
......@@ -312,7 +312,7 @@ int run(int argc, char* argv[])
}
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout;
auto gemm = DeviceGemmInstance{};
......@@ -686,7 +686,7 @@ int run(int argc, char* argv[])
lse_g_ms[i],
p_drop_g_m_ns[i],
z_g_m_ns[i],
p_dropout_in_16bits,
p_dropout_in_uint8_t,
rp_dropout);
y_tensors[i].ForEach([&](auto& self, auto idx) {
......@@ -738,7 +738,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
......
......@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ZDataType p_dropout_in_16bits,
ZDataType p_dropout_in_uint8_t,
float rp_dropout)
{
// S = alpha * Q * K^T
......@@ -249,7 +249,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V
......@@ -313,7 +313,7 @@ int run(int argc, char* argv[])
}
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout;
auto gemm = DeviceGemmInstance{};
......@@ -699,7 +699,7 @@ int run(int argc, char* argv[])
lse_g_ms[i],
p_drop_g_m_ns[i],
z_g_m_ns[i],
p_dropout_in_16bits,
p_dropout_in_uint8_t,
rp_dropout);
y_tensors[i].ForEach([&](auto& self, auto idx) {
......@@ -751,7 +751,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
......
......@@ -246,7 +246,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ZDataType p_dropout_in_16bits,
ZDataType p_dropout_in_uint8_t,
float rp_dropout)
{
// S = alpha * Q * K^T
......@@ -278,7 +278,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V
......@@ -342,7 +342,7 @@ int run(int argc, char* argv[])
}
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout;
auto gemm_fwd = DeviceGemmInstanceFWD{};
......@@ -856,7 +856,7 @@ int run(int argc, char* argv[])
lse_g_ms[i],
p_drop_g_m_ns[i],
z_fwd_g_m_ns[i],
p_dropout_in_16bits,
p_dropout_in_uint8_t,
rp_dropout);
int G0 = v_tensors[i].GetLengths()[0];
......@@ -889,7 +889,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
z_bwd_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
z_bwd_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
......
......@@ -67,7 +67,7 @@ int run(int argc, char* argv[])
}
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
......@@ -159,6 +159,7 @@ int run(int argc, char* argv[])
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
......@@ -322,7 +323,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// gemm1
......
......@@ -44,7 +44,7 @@ int run(int argc, char* argv[])
}
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1; // scaling after 1st gemm
......@@ -217,6 +217,7 @@ int run(int argc, char* argv[])
a_tensors_device[i]->ToDevice(a_gs_ms_ks.mData.data());
b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data());
b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data());
z_tensors_device[i]->ToDevice(z_gs_ms_ns.mData.data());
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
......@@ -390,7 +391,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// gemm 1
......
......@@ -16,111 +16,111 @@ struct BlockwiseDropout
static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
template <typename CThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_8x16((tmp + i * 8));
}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
tmp_index = tmp_index + 1;
});
});
}
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void
ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_8x16((tmp + i * 8));
}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
}
template <typename CThreadBuffer,
typename ZThreadBuffer,
bool using_sign_bit,
typename N0,
typename Offset>
__host__ __device__ void
ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_8x16((tmp + i * 8));
}
block_sync_lds();
constexpr auto iOffset = Number<tmp_size>{} * Offset{};
static_for<0, tmp_size, 1>{}([&](auto i) {
in_thread_buf(i + iOffset) =
execute_dropout(tmp[i.value] <= p_dropout_16bits, in_thread_buf(i + iOffset));
z_thread_buf(i) = tmp[i.value];
});
}
// template <typename CThreadBuffer, bool using_sign_bit = false>
// __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph)
// {
// auto execute_dropout = [&](bool keep, DataType val) {
// if constexpr(using_sign_bit)
// return keep ? val : -val;
// else
// return keep ? val * p_dropout_rescale : float(0);
// };
// constexpr int tmp_size = MRepeat * KRepeat;
// int philox_calls = tmp_size / 8;
// ushort tmp[tmp_size];
// for(int i = 0; i < philox_calls; i++)
// {
// ph.get_random_8x16((tmp + i * 8));
// }
// block_sync_lds();
// int tmp_index = 0;
// static_for<0, MRepeat, 1>{}([&](auto iM) {
// static_for<0, KRepeat, 1>{}([&](auto iK) {
// auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM,
// iK))>{}; in_thread_buf(offset) =
// execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
// tmp_index = tmp_index + 1;
// });
// });
// }
// template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
// __host__ __device__ void
// ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
// {
// auto execute_dropout = [&](bool keep, DataType val) {
// if constexpr(using_sign_bit)
// return keep ? val : -val;
// else
// return keep ? val * p_dropout_rescale : float(0);
// };
// constexpr int tmp_size = MRepeat * KRepeat;
// int philox_calls = tmp_size / 8;
// ushort tmp[tmp_size];
// for(int i = 0; i < philox_calls; i++)
// {
// ph.get_random_8x16((tmp + i * 8));
// }
// block_sync_lds();
// int tmp_index = 0;
// static_for<0, MRepeat, 1>{}([&](auto iM) {
// static_for<0, KRepeat, 1>{}([&](auto iK) {
// auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM,
// iK))>{}; in_thread_buf(offset) =
// execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
// z_thread_buf(offset) = tmp[tmp_index];
// tmp_index = tmp_index + 1;
// });
// });
// }
// template <typename CThreadBuffer,
// typename ZThreadBuffer,
// bool using_sign_bit,
// typename N0,
// typename Offset>
// __host__ __device__ void
// ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
// {
// auto execute_dropout = [&](bool keep, DataType val) {
// if constexpr(using_sign_bit)
// return keep ? val : -val;
// else
// return keep ? val * p_dropout_rescale : float(0);
// };
// constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
// int philox_calls = tmp_size / 8;
// ushort tmp[tmp_size];
// for(int i = 0; i < philox_calls; i++)
// {
// ph.get_random_8x16((tmp + i * 8));
// }
// block_sync_lds();
// constexpr auto iOffset = Number<tmp_size>{} * Offset{};
// static_for<0, tmp_size, 1>{}([&](auto i) {
// in_thread_buf(i + iOffset) =
// execute_dropout(tmp[i.value] <= p_dropout_uint8_t, in_thread_buf(i + iOffset));
// z_thread_buf(i) = tmp[i.value];
// });
// }
template <typename CThreadBuffer, typename Offset, bool using_sign_bit = false>
__host__ __device__ void ApplyDropoutAttnBwd(CThreadBuffer& in_thread_buf,
......@@ -138,12 +138,12 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 8;
int philox_calls = tmp_size / 16;
ushort tmp[tmp_size];
uint8_t tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
ph.get_random_16x8((tmp + i * 16), element_global_1d_id + i * Offset{} * MRaw);
}
block_sync_lds();
......@@ -153,7 +153,7 @@ struct BlockwiseDropout
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
tmp_index = tmp_index + 1;
});
});
......@@ -179,12 +179,12 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 8;
int philox_calls = tmp_size / 16;
ushort tmp[tmp_size];
uint8_t tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
ph.get_random_16x8((tmp + i * 16), element_global_1d_id + i * Offset{} * MRaw);
}
block_sync_lds();
......@@ -194,7 +194,7 @@ struct BlockwiseDropout
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
......@@ -213,7 +213,7 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
static_for<0, tmp_size, 1>{}([&](auto i) {
in_thread_buf(i + Offset{}) =
execute_dropout(z_thread_buf(i) <= p_dropout_16bits, in_thread_buf(i + Offset{}));
execute_dropout(z_thread_buf(i) <= p_dropout_uint8_t, in_thread_buf(i + Offset{}));
});
}
......@@ -225,18 +225,18 @@ struct BlockwiseDropout
{
constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
int philox_calls = tmp_size / 8;
int philox_calls = tmp_size / 16;
ushort tmp[tmp_size];
uint8_t tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{});
ph.get_random_16x8((tmp + i * 16), element_global_1d_id + i * Offset{});
}
static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; });
}
ushort p_dropout_16bits;
uint8_t p_dropout_uint8_t;
DataType p_dropout_rescale;
};
......
......@@ -40,7 +40,7 @@ template <typename GridwiseGemm,
typename D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename LSEGridDescriptor_M,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
......@@ -73,15 +73,15 @@ __global__ void
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const LSEGridDescriptor_M lse_grid_desc_m,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const index_t mblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const ushort p_dropout_in_16bits,
const uint8_t p_dropout_in_uint8_t,
const GemmAccDataType p_dropout_rescale,
const unsigned long long seed,
const unsigned long long offset,
......@@ -145,11 +145,11 @@ __global__ void
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
lse_grid_desc_m,
block_2_ctile_map,
c0_matrix_mask,
p_dropout_in_16bits,
p_dropout_in_uint8_t,
p_dropout_rescale,
ph,
z_random_matrix_offset,
......@@ -178,11 +178,11 @@ __global__ void
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
lse_grid_desc_m,
block_2_ctile_map,
c0_matrix_mask,
p_dropout_in_16bits,
p_dropout_in_uint8_t,
p_dropout_rescale,
ph,
z_random_matrix_offset,
......@@ -207,14 +207,14 @@ __global__ void
ignore = d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6;
ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = lse_grid_desc_m;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = mblock;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = p_dropout_in_16bits;
ignore = p_dropout_in_uint8_t;
ignore = p_dropout_rescale;
ignore = seed;
ignore = offset;
......@@ -697,16 +697,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
is_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
p_dropout_in_uint8_t_ = uint8_t(std::floor(p_dropout_ * 255.0));
p_dropout_ = 1.f / p_dropout_;
p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
z_grid_desc_m_n_);
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
......@@ -779,8 +778,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
......@@ -806,7 +805,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_dropout_;
ushort p_dropout_in_16bits_;
uint8_t p_dropout_in_uint8_t_;
GemmAccDataType p_dropout_rescale_;
unsigned long long seed_;
unsigned long long offset_;
......@@ -864,7 +863,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
DeviceOp::LSEGridDesc_M,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
......@@ -897,14 +896,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.lse_grid_desc_m_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_dropout_in_16bits_,
arg.p_dropout_in_uint8_t_,
arg.p_dropout_rescale_,
arg.seed_,
arg.offset_,
......
......@@ -48,7 +48,7 @@ __global__ void
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const ushort p_dropout_in_16bits,
const uint8_t p_dropout_in_uint8_t,
const GemmAccDataType p_dropout_rescale,
const unsigned long long seed,
const unsigned long long offset)
......@@ -140,11 +140,11 @@ __global__ void
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
p_dropout_in_16bits,
p_dropout_in_uint8_t,
p_dropout_rescale,
ph,
arg_ptr[group_id].z_random_matrix_offset_ +
......@@ -178,11 +178,11 @@ __global__ void
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
p_dropout_in_16bits,
p_dropout_in_uint8_t,
p_dropout_rescale,
ph,
arg_ptr[group_id].z_random_matrix_offset_ +
......@@ -198,7 +198,7 @@ __global__ void
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = p_dropout_in_16bits;
ignore = p_dropout_in_uint8_t;
ignore = p_dropout_rescale;
ignore = seed;
ignore = offset;
......@@ -620,8 +620,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
ZGridDesc_M_N z_grid_desc_m_n_;
LSEGridDesc_M lse_grid_desc_m_;
......@@ -774,8 +774,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n);
const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
z_grid_desc_m_n);
const index_t BlockStart = grid_size_;
......@@ -819,7 +819,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_desc_m_n,
lse_grid_desc_m,
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
......@@ -859,7 +859,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
use_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
p_dropout_in_uint8_t_ = uint8_t(std::floor(p_dropout_ * 255.0));
p_dropout_ = 1.f / p_dropout_;
p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);
......@@ -880,7 +880,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
CElementwiseOperation c_element_op_;
float p_dropout_;
ushort p_dropout_in_16bits_;
uint8_t p_dropout_in_uint8_t_;
unsigned long long seed_;
unsigned long long offset_;
GemmAccDataType p_dropout_rescale_;
......@@ -949,7 +949,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
arg.acc_element_op_,
arg.b1_element_op_,
arg.c_element_op_,
arg.p_dropout_in_16bits_,
arg.p_dropout_in_uint8_t_,
arg.p_dropout_rescale_,
arg.seed_,
arg.offset_);
......
......@@ -120,8 +120,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr auto V_K0 = KPerBlock / V_K1 / V_K2;
static constexpr auto V_N1 = NXdlPerWave;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
// get_random_16x8() generates 16 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......@@ -1409,8 +1409,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits =
__builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
const uint8_t p_dropout_in_uint8_t =
__builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout);
......@@ -1726,7 +1726,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_16bits, rp_dropout};
p_dropout_in_uint8_t, rp_dropout};
auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
......@@ -1795,7 +1795,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr,
ushort,
uint8_t,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tensor_buffer;
......@@ -1805,7 +1805,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
uint8_t,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
......
......@@ -133,8 +133,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto V_K0 = Gemm1NPerBlock / KPerBlock;
static constexpr auto V_N1 = NXdlPerWave;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
// get_random_16x8() generates 16 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......@@ -1506,8 +1506,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits =
__builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
const uint8_t p_dropout_in_uint8_t =
__builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout);
......@@ -1848,7 +1848,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_16bits, rp_dropout};
p_dropout_in_uint8_t, rp_dropout};
auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
......@@ -1917,7 +1917,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr,
ushort,
uint8_t,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tensor_buffer;
......@@ -1927,7 +1927,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
uint8_t,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
......
......@@ -119,8 +119,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static constexpr auto V_K0 = KPerBlock / V_K1 / V_K2;
static constexpr auto V_N1 = NXdlPerWave;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
// get_random_16x8() generates 16 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......@@ -1492,8 +1492,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits =
__builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
const uint8_t p_dropout_in_uint8_t =
__builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout);
......@@ -1809,7 +1809,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_16bits, rp_dropout};
p_dropout_in_uint8_t, rp_dropout};
auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
......@@ -1859,7 +1859,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr,
ushort,
uint8_t,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tensor_buffer;
......@@ -1869,7 +1869,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
uint8_t,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
......
......@@ -132,8 +132,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto V_K0 = Gemm1NPerBlock / KPerBlock;
static constexpr auto V_N1 = NXdlPerWave;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
// get_random_16x8() generates 16 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......@@ -1478,8 +1478,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset = k_block_space_size_aligned.value *
sizeof(GemmDataType) /
static constexpr auto d0_block_space_offset =
k_block_space_size_aligned.value * sizeof(GemmDataType) /
D0Loader::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
......@@ -1564,8 +1564,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits =
__builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
const uint8_t p_dropout_in_uint8_t =
__builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout);
......@@ -1906,7 +1906,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_16bits, rp_dropout};
p_dropout_in_uint8_t, rp_dropout};
auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
......@@ -1956,7 +1956,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr,
ushort,
uint8_t,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tensor_buffer;
......@@ -1966,7 +1966,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
uint8_t,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
......
......@@ -113,8 +113,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I9 = Number<9>{};
static constexpr auto WaveSize = 64;
......@@ -134,17 +132,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
static constexpr auto DropoutMThread = DropoutTile; // 16
static constexpr auto DropoutTilePerXdl = NPerXdl / DropoutTile; // 2
static constexpr auto DropoutStep = Number<DropoutStepValue>{}; // 1 2 4
static constexpr auto DropoutNRepeat =
Number<math::integer_divide_ceil(DropoutStep, DropoutTilePerXdl)>{}; // 1 1 2
static constexpr auto DropoutGroupPerTile =
Number<mfma.num_groups_per_blk / DropoutTilePerXdl>{}; // 2
static constexpr auto DropoutStepPerXdl =
Number<math::min(DropoutStep, DropoutTilePerXdl)>{}; // 1 2 2
// get_random_16x8() generates 16 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
static constexpr auto DropoutStep = Number<DropoutStepValue>{}; // 1 2
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......@@ -152,51 +142,45 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// C desc for source in gridwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
__host__ __device__ static constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
const auto M0 = M / MPerBlock;
const auto N0 = N / (DropoutNRepeat * NPerXdl);
const auto N0 = N / (DropoutStep * NPerXdl);
constexpr auto M1 = MXdlPerWave;
constexpr auto N1 = DropoutNRepeat;
constexpr auto N1 = DropoutStep;
constexpr auto M2 = Gemm0MWaves;
constexpr auto N2 = Gemm0NWaves;
constexpr auto M3 = DropoutTilePerXdl;
constexpr auto N3 = DropoutStepPerXdl;
constexpr auto M4 = DropoutTile;
constexpr auto N4 = DropoutGroupPerTile;
constexpr auto N5 = mfma.num_input_blks;
constexpr auto N6 = mfma.group_size;
constexpr auto M3 = DropoutTile;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4, N5, N6))),
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3)),
make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6, 8>{}, Sequence<1, 3, 5, 7, 9, 10, 11>{}));
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__host__ __device__ static constexpr auto
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5()
__host__ __device__ static constexpr auto GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto M0 = MXdlPerWave;
constexpr auto N0 = DropoutNRepeat;
constexpr auto N0 = DropoutStep;
constexpr auto M1 = Gemm0MWaves;
constexpr auto N1 = Gemm0NWaves;
constexpr auto M2 = DropoutTilePerXdl;
constexpr auto N2 = DropoutStepPerXdl;
constexpr auto M3 = DropoutTile;
constexpr auto N3 = DropoutGroupPerTile;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
constexpr auto M2 = DropoutTile;
constexpr auto N2 = mfma.num_groups_per_blk;
constexpr auto N3 = mfma.num_input_blks;
constexpr auto N4 = mfma.group_size;
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(M0, N0, M1, N1, M2, N2, M3, N3, N4, N5));
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
make_naive_tensor_descriptor_packed(make_tuple(M0, N0, M1, N1, M2, N2, N3, N4));
return z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
return z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
}
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
......@@ -317,7 +301,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
const index_t z_block_bytes_end =
SharedMemTrait::z_shuffle_block_space_size * sizeof(ushort);
SharedMemTrait::z_shuffle_block_space_size * sizeof(uint8_t);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
......@@ -468,8 +452,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(ZGridDesc_M_N{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
struct SharedMemTrait
{
......@@ -507,10 +491,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
// LDS allocation for Z shuffle in LDS
static constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5();
static constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
static constexpr auto z_shuffle_block_space_size =
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize();
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize();
};
template <bool HasMainKBlockLoop,
......@@ -538,12 +522,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const LSEGridDesc_M& lse_grid_desc_m,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
const ushort p_dropout_in_16bits,
const uint8_t p_dropout_in_uint8_t,
FloatGemmAcc p_dropout_rescale,
ck::philox& ph,
const index_t z_random_matrix_offset,
......@@ -894,7 +878,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_16bits, p_dropout_rescale};
p_dropout_in_uint8_t, p_dropout_rescale};
const index_t num_gemm1_k_block_outer_loop =
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
......@@ -992,26 +976,22 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
wave_m_n_id[I0], // NInputIndex
0)); // register number
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = // for blockwise copy
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
DropoutNRepeat, // NRepeat
DropoutStep, // NRepeat
m1, // MWaveId
n1, // NWaveId
I1,
DropoutStepPerXdl,
m2,
DropoutGroupPerTile,
n2,
n3,
n4)); // RegisterNum
constexpr auto z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = // for blockwise copy
constexpr auto z_shuffle_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
DropoutNRepeat, // NRepeat
DropoutStep, // NRepeat
m1, // MWaveId
n1, // NWaveId
I1,
DropoutStepPerXdl,
DropoutGroupPerTile,
n2,
n3,
n4, // RegisterNum
m2));
......@@ -1020,176 +1000,146 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6 =
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockId
m0, // MRepeat
DropoutNRepeat, // NRepeat
DropoutStep, // NRepeat
m1, // MWaveId
n1, // NWaveId
I1,
DropoutStepPerXdl,
m2,
DropoutGroupPerTile,
n2,
n3,
n4)); // RegisterNum
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5();
constexpr auto ZM0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I0); // 1
constexpr auto ZN0 =
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I1); // 1 1 2
constexpr auto ZM1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I2); // 4
constexpr auto ZN1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I3); // 1
constexpr auto ZM2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I4); // 2
constexpr auto ZN2 =
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I5); // 1 2 2
constexpr auto ZM3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I6); // 16
constexpr auto ZN3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I7); // 2
constexpr auto ZN4 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I8); // 2
constexpr auto ZN5 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I9); // 4
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
transform_tensor_descriptor(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto ZM0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); // 1
constexpr auto ZN0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); // 1 2
constexpr auto ZM1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); // 4
constexpr auto ZN1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); // 1
constexpr auto ZN2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); // 4
constexpr auto ZN3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); // 2
constexpr auto ZN4 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); // 4
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_pass_through_transform(ZM0),
make_pass_through_transform(ZN0),
make_pass_through_transform(ZM1),
make_pass_through_transform(ZN1),
make_pass_through_transform(ZM2),
make_pass_through_transform(ZN2),
make_unmerge_transform(make_tuple(ZM3 / ZN4 / ZN5, ZN4, ZN5)),
make_merge_transform_v3_division_mod(make_tuple(ZN3, ZN4, ZN5))),
make_unmerge_transform(make_tuple(ZN2, ZN3, ZN4)),
make_merge_transform_v3_division_mod(make_tuple(ZN2, ZN3, ZN4))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7, 8, 9>{}),
Sequence<5, 6, 7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6, 7, 8>{},
Sequence<9>{}));
Sequence<4, 5, 6>{},
Sequence<7>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
ushort,
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
uint8_t,
z_shuffle_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6.GetElementSpaceSize());
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ushort*>(p_shared),
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
auto z_tmp_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
static_cast<uint8_t*>(p_shared),
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
auto z_tmp_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<uint8_t,
uint8_t,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough,
Sequence<m0, // MRepeat
DropoutNRepeat, // NRepeat
DropoutStep, // NRepeat
m1, // MWaveId
n1, // NWaveId
I1,
DropoutStepPerXdl,
m2,
DropoutGroupPerTile,
n2,
n3,
n4>, // RegisterNum
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1] / DropoutMThread,
0,
wave_m_n_id[I1] % DropoutMThread,
wave_m_n_id[I1],
0,
wave_m_n_id[I0],
0),
tensor_operation::element_wise::PassThrough{}};
auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ushort,
ushort,
decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
Sequence<m0,
DropoutNRepeat,
m1,
n1,
I1,
DropoutStepPerXdl,
DropoutGroupPerTile,
n3,
n4,
m2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
auto z_shuffle_thread_copy_lds_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<uint8_t,
uint8_t,
decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(
z_shuffle_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
Sequence<m0, DropoutStep, m1, n1, n2, n3, n4, m2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
1,
true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1] / DropoutMThread,
0,
0,
wave_m_n_id[I0],
0,
wave_m_n_id[I1] % DropoutMThread)};
wave_m_n_id[I1])};
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
auto z_thread_copy_vgpr_to_global =
ThreadwiseTensorSliceTransfer_v1r3<uint8_t,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6),
decltype(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
DropoutNRepeat, // NRepeat
DropoutStep, // NRepeat
m1, // MWaveId
n1, // NWaveId
I1,
DropoutStepPerXdl,
m2,
DropoutGroupPerTile,
n2,
n3,
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11>,
11, // DstVectorDim
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1] / DropoutMThread,
0,
wave_m_n_id[I1] % DropoutMThread,
wave_m_n_id[I1],
0,
wave_m_n_id[I0],
0),
......@@ -1321,8 +1271,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
blockwise_softmax.Run(acc_thread_buf, workspace_buf);
constexpr auto iterator_offset = Number<8 * DropoutStep>{};
constexpr auto iterator_step = Number<n0 * n1 * n2 * n3 * n4 / 8 / DropoutStep>{};
constexpr auto iterator_offset = Number<16 * DropoutStep>{};
constexpr auto iterator_step = Number<m0 * n0 * n1 * n2 * n3 * n4 / 16 / DropoutStep>{};
if constexpr(IsDropout) // dropout
{
......@@ -1343,18 +1293,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
decltype(DropoutTile)>(
ph, global_elem_id, z_tensor_buffer);
z_tmp_thread_copy_vgpr_to_lds.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tmp_thread_copy_vgpr_to_lds.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
z_block_buf);
z_shuffle_thread_copy_lds_to_vgpr.Run(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_shuffle_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
z_block_buf,
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_shuffle_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer);
blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
......@@ -1367,14 +1316,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
if(p_z_grid && (gemm1_n_block_data_idx_on_grid == 0))
{
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0));
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
});
}
......
......@@ -84,6 +84,19 @@ class philox
out_tmp[3] = tmp_ph.w;
}
__device__ void get_random_16x8(uint8_t* out, const unsigned long long subsequence)
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp_ph.x;
out_tmp[1] = tmp_ph.y;
out_tmp[2] = tmp_ph.z;
out_tmp[3] = tmp_ph.w;
}
__device__ void get_random_4x16(ushort* out, const unsigned long long subsequence)
{
uint4 tmp_ph;
......
......@@ -25,19 +25,19 @@ struct ReferenceDropout : public device::BaseOperator
Argument(const Tensor<RefDataType>& ref,
const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
RefDataType p_dropout_in_16bits,
RefDataType p_dropout_in_uint8_t,
float rp_dropout)
: ref_(ref),
in_(in),
out_(out),
p_dropout_in_16bits_(p_dropout_in_16bits),
p_dropout_in_uint8_t_(p_dropout_in_uint8_t),
rp_dropout_(rp_dropout)
{
}
const Tensor<RefDataType>& ref_;
const Tensor<InDataType>& in_;
Tensor<OutDataType>& out_;
RefDataType p_dropout_in_16bits_;
RefDataType p_dropout_in_uint8_t_;
float rp_dropout_;
};
......@@ -48,7 +48,7 @@ struct ReferenceDropout : public device::BaseOperator
{
arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) =
arg.ref_(idx) <= arg.p_dropout_in_16bits_
arg.ref_(idx) <= arg.p_dropout_in_uint8_t_
? ck::type_convert<OutDataType>(ck::type_convert<float>(arg.in_(idx)) *
ck::type_convert<float>(arg.rp_dropout_))
: 0;
......@@ -74,10 +74,10 @@ struct ReferenceDropout : public device::BaseOperator
static auto MakeArgument(const Tensor<RefDataType>& ref,
const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
RefDataType p_dropout_in_16bits,
RefDataType p_dropout_in_uint8_t,
float rp_dropout)
{
return Argument{ref, in, out, p_dropout_in_16bits, rp_dropout};
return Argument{ref, in, out, p_dropout_in_uint8_t, rp_dropout};
}
static auto MakeInvoker() { return Invoker{}; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment