Commit 8ced5c4f authored by danyao12's avatar danyao12
Browse files

bias examples sync with uint8 dropout

parent 0353c29e
...@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorLSE& lse_g_m, TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n, TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n, TensorZ& z_g_m_n,
ZDataType p_dropout_in_16bits, ZDataType p_dropout_in_uint8_t,
float rp_dropout) float rp_dropout)
{ {
// S = alpha * Q * K^T // S = alpha * Q * K^T
...@@ -252,7 +252,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -252,7 +252,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
auto ref_dropout = ReferenceDropoutInstance{}; auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker(); auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout); ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment); ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V // Y = P_dropout * V
...@@ -328,7 +328,7 @@ int run(int argc, char* argv[]) ...@@ -328,7 +328,7 @@ int run(int argc, char* argv[])
} }
float p_dropout = 1 - p_drop; float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0)); ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout; float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K); float alpha = 1.f / std::sqrt(K);
...@@ -655,7 +655,7 @@ int run(int argc, char* argv[]) ...@@ -655,7 +655,7 @@ int run(int argc, char* argv[])
lse_g_m, lse_g_m,
p_drop_g_m_n, p_drop_g_m_n,
z_g_m_n, z_g_m_n,
p_dropout_in_16bits, p_dropout_in_uint8_t,
rp_dropout); rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) { y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
...@@ -715,7 +715,7 @@ int run(int argc, char* argv[]) ...@@ -715,7 +715,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{}; auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker(); auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument( auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout); z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment); ref_dropout_invoker.Run(ref_dropout_argment);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i) // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
......
...@@ -216,7 +216,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -216,7 +216,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorLSE& lse_g_m, TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n, TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n, TensorZ& z_g_m_n,
ZDataType p_dropout_in_16bits, ZDataType p_dropout_in_uint8_t,
float rp_dropout) float rp_dropout)
{ {
// S = alpha * Q * K^T // S = alpha * Q * K^T
...@@ -251,7 +251,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -251,7 +251,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
auto ref_dropout = ReferenceDropoutInstance{}; auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker(); auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout); ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment); ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V // Y = P_dropout * V
...@@ -315,7 +315,7 @@ int run(int argc, char* argv[]) ...@@ -315,7 +315,7 @@ int run(int argc, char* argv[])
} }
float p_dropout = 1 - p_drop; float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0)); ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout; float rp_dropout = 1.0 / p_dropout;
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
...@@ -719,7 +719,7 @@ int run(int argc, char* argv[]) ...@@ -719,7 +719,7 @@ int run(int argc, char* argv[])
lse_g_ms[i], lse_g_ms[i],
p_drop_g_m_ns[i], p_drop_g_m_ns[i],
z_g_m_ns[i], z_g_m_ns[i],
p_dropout_in_16bits, p_dropout_in_uint8_t,
rp_dropout); rp_dropout);
y_tensors[i].ForEach([&](auto& self, auto idx) { y_tensors[i].ForEach([&](auto& self, auto idx) {
...@@ -772,7 +772,7 @@ int run(int argc, char* argv[]) ...@@ -772,7 +772,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{}; auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker(); auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument( auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout); z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment); ref_dropout_invoker.Run(ref_dropout_argment);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i) // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) { sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
......
...@@ -67,7 +67,7 @@ int run(int argc, char* argv[]) ...@@ -67,7 +67,7 @@ int run(int argc, char* argv[])
} }
float p_dropout = 1 - p_drop; float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0)); ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout; float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K); float alpha = 1.f / std::sqrt(K);
...@@ -172,6 +172,7 @@ int run(int argc, char* argv[]) ...@@ -172,6 +172,7 @@ int run(int argc, char* argv[])
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{}; auto b0_element_op = B0ElementOp{};
...@@ -322,7 +323,9 @@ int run(int argc, char* argv[]) ...@@ -322,7 +323,9 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias // bias
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); }); acc0_g_m_n.ForEach([&](auto& self, auto idx) {
self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx));
});
// masking // masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
...@@ -342,7 +345,7 @@ int run(int argc, char* argv[]) ...@@ -342,7 +345,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{}; auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker(); auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument( auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout); z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment); ref_dropout_invoker.Run(ref_dropout_argment);
// gemm1 // gemm1
......
...@@ -44,7 +44,7 @@ int run(int argc, char* argv[]) ...@@ -44,7 +44,7 @@ int run(int argc, char* argv[])
} }
float p_dropout = 1 - p_drop; float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout; float rp_dropout = 1.0 / p_dropout;
float alpha = 1; // scaling after 1st gemm float alpha = 1; // scaling after 1st gemm
...@@ -163,8 +163,9 @@ int run(int argc, char* argv[]) ...@@ -163,8 +163,9 @@ int run(int argc, char* argv[])
int Batch = G0 * G1; int Batch = G0 * G1;
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + num_byte +=
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
sizeof(CDataType) * M * O +
sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) * sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
Batch; Batch;
...@@ -237,6 +238,7 @@ int run(int argc, char* argv[]) ...@@ -237,6 +238,7 @@ int run(int argc, char* argv[])
b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data()); b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data());
b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data()); b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data());
d_tensors_device[i]->ToDevice(d_gs_ms_ns.mData.data()); d_tensors_device[i]->ToDevice(d_gs_ms_ns.mData.data());
z_tensors_device[i]->ToDevice(z_gs_ms_ns.mData.data());
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer()); p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
...@@ -396,7 +398,9 @@ int run(int argc, char* argv[]) ...@@ -396,7 +398,9 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias // bias
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); }); acc0_g_m_n.ForEach([&](auto& self, auto idx) {
self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx));
});
// masking // masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
...@@ -419,7 +423,7 @@ int run(int argc, char* argv[]) ...@@ -419,7 +423,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{}; auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker(); auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument( auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout); z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment); ref_dropout_invoker.Run(ref_dropout_argment);
// gemm 1 // gemm 1
......
...@@ -57,8 +57,8 @@ struct GridwiseBatchedDropout ...@@ -57,8 +57,8 @@ struct GridwiseBatchedDropout
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2 static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time // get_random_16x8() generates 16 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16 static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
...@@ -241,7 +241,7 @@ struct GridwiseBatchedDropout ...@@ -241,7 +241,7 @@ struct GridwiseBatchedDropout
// only used for providing ApplyDropoutAttnBwdSaveZ // only used for providing ApplyDropoutAttnBwdSaveZ
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{ auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
static_cast<unsigned short>(0.8f * 65535.f), static_cast<FloatGemmAcc>(1.0f / 0.8f)}; static_cast<unsigned short>(0.8f * 255.f), static_cast<FloatGemmAcc>(1.0f / 0.8f)};
// //
// z vgpr copy to global // z vgpr copy to global
...@@ -260,7 +260,7 @@ struct GridwiseBatchedDropout ...@@ -260,7 +260,7 @@ struct GridwiseBatchedDropout
n2)); // NPerXdl n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
ushort, uint8_t,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(), z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true> true>
z_tensor_buffer; z_tensor_buffer;
...@@ -273,7 +273,7 @@ struct GridwiseBatchedDropout ...@@ -273,7 +273,7 @@ struct GridwiseBatchedDropout
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort, uint8_t,
ZDataType, ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3), decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3), decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment