Commit 8ced5c4f authored by danyao12

bias examples sync with uint8 dropout

parent 0353c29e
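
The hunks below swap the 16-bit dropout threshold, floor(p_dropout * 65535.0), for an 8-bit one, floor(p_dropout * 255.0), while keeping the rescale factor rp_dropout = 1 / p_dropout. A minimal standalone sketch of that conversion follows; the 0.2f drop rate and the element-wise comparison direction (z <= threshold) are illustrative assumptions, not code taken from the repository.

#include <cmath>
#include <cstdint>

// Sketch: quantize the keep-probability into the uint8 threshold used below.
// The drop rate and the (z <= threshold) comparison are illustrative assumptions.
int main()
{
    float p_drop    = 0.2f;             // drop probability
    float p_dropout = 1.f - p_drop;     // keep probability
    uint8_t p_dropout_in_uint8_t =
        static_cast<uint8_t>(std::floor(p_dropout * 255.0)); // previously * 65535.0 into a 16-bit value
    float rp_dropout = 1.f / p_dropout; // rescale applied to kept elements

    uint8_t z   = 37;                   // one random byte from the Z tensor
    float   in  = 0.5f;
    float   out = (z <= p_dropout_in_uint8_t) ? in * rp_dropout : 0.f;
    (void)out;
    return 0;
}
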
......@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
- ZDataType p_dropout_in_16bits,
+ ZDataType p_dropout_in_uint8_t,
float rp_dropout)
{
// S = alpha * Q * K^T
......@@ -252,7 +252,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
- ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+ ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V
......@@ -328,7 +328,7 @@ int run(int argc, char* argv[])
}
float p_dropout = 1 - p_drop;
- ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
+ ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
......@@ -655,7 +655,7 @@ int run(int argc, char* argv[])
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
- p_dropout_in_16bits,
+ p_dropout_in_uint8_t,
rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
......@@ -715,7 +715,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
- z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+ z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
......
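
The forward P-dropout and backward dP-dropout reference calls above pass the same five arguments: the Z mask, an input tensor, an output tensor, the uint8 threshold, and the rescale factor. As a rough mental model only, here is a hypothetical stand-in for such an element-wise reference operator; dropout_ref is not the library's ReferenceDropout, and the comparison direction is an assumption.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a host-side reference dropout: keep (and rescale)
// an element where its random byte is at or below the threshold, zero it
// otherwise. Not the library implementation.
template <typename T>
void dropout_ref(const std::vector<uint8_t>& z,
                 const std::vector<T>& in,
                 std::vector<T>& out,
                 uint8_t threshold,
                 float rp_dropout)
{
    out.resize(in.size());
    for(std::size_t i = 0; i < in.size(); ++i)
    {
        out[i] = (z[i] <= threshold) ? static_cast<T>(in[i] * rp_dropout) : T{0};
    }
}
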
......@@ -216,7 +216,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
- ZDataType p_dropout_in_16bits,
+ ZDataType p_dropout_in_uint8_t,
float rp_dropout)
{
// S = alpha * Q * K^T
......@@ -251,7 +251,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
- ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+ ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V
......@@ -315,7 +315,7 @@ int run(int argc, char* argv[])
}
float p_dropout = 1 - p_drop;
- ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
+ ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout;
auto gemm = DeviceGemmInstance{};
......@@ -719,7 +719,7 @@ int run(int argc, char* argv[])
lse_g_ms[i],
p_drop_g_m_ns[i],
z_g_m_ns[i],
- p_dropout_in_16bits,
+ p_dropout_in_uint8_t,
rp_dropout);
y_tensors[i].ForEach([&](auto& self, auto idx) {
......@@ -772,7 +772,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
- z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+ z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
......
......@@ -67,7 +67,7 @@ int run(int argc, char* argv[])
}
float p_dropout = 1 - p_drop;
- ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
+ ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
......@@ -172,6 +172,7 @@ int run(int argc, char* argv[])
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
+ z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
......@@ -322,7 +323,9 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias
- acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
+ acc0_g_m_n.ForEach([&](auto& self, auto idx) {
+     self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx));
+ });
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
......@@ -342,7 +345,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
- z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+ z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// gemm1
......
......@@ -44,7 +44,7 @@ int run(int argc, char* argv[])
}
float p_dropout = 1 - p_drop;
- uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+ ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1; // scaling after 1st gemm
......@@ -163,8 +163,9 @@ int run(int argc, char* argv[])
int Batch = G0 * G1;
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
- num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-              sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
+ num_byte +=
+     (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
+      sizeof(CDataType) * M * O +
sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
Batch;
......@@ -237,6 +238,7 @@ int run(int argc, char* argv[])
b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data());
b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data());
d_tensors_device[i]->ToDevice(d_gs_ms_ns.mData.data());
+ z_tensors_device[i]->ToDevice(z_gs_ms_ns.mData.data());
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
......@@ -396,7 +398,9 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias
- acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
+ acc0_g_m_n.ForEach([&](auto& self, auto idx) {
+     self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx));
+ });
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
......@@ -419,7 +423,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
- z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+ z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// gemm 1
......
......@@ -57,8 +57,8 @@ struct GridwiseBatchedDropout
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
- // get_random_8x16() generates 8 random numbers each time
- static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+ // get_random_16x8() generates 16 random numbers each time
+ static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......@@ -241,7 +241,7 @@ struct GridwiseBatchedDropout
// only used for providing ApplyDropoutAttnBwdSaveZ
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
- static_cast<unsigned short>(0.8f * 65535.f), static_cast<FloatGemmAcc>(1.0f / 0.8f)};
+ static_cast<unsigned short>(0.8f * 255.f), static_cast<FloatGemmAcc>(1.0f / 0.8f)};
//
// z vgpr copy to global
......@@ -260,7 +260,7 @@ struct GridwiseBatchedDropout
n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr,
- ushort,
+ uint8_t,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tensor_buffer;
......@@ -273,7 +273,7 @@ struct GridwiseBatchedDropout
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
- ushort,
+ uint8_t,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
......
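
For reference, the tile-size change in the last hunk follows from the quoted comments: with mfma.num_input_blks equal to 2 and get_random_16x8() now yielding 16 random bytes per call, each dropout tile covers 2 * 16 = 32 elements rather than 2 * 8 = 16. A trivial sketch of that arithmetic, with the constants copied from those comments:

#include <cstdio>

// DropoutTile arithmetic from the hunk above; the constants mirror the
// source comments (num_input_blks == 2, 16 random uint8 values per call).
int main()
{
    constexpr int dropout_n_thread = 2;                                   // mfma.num_input_blks
    constexpr int randoms_per_call = 16;                                  // get_random_16x8()
    constexpr int dropout_tile     = dropout_n_thread * randoms_per_call; // 32
    constexpr int old_dropout_tile = dropout_n_thread * 8;                // 16, before this commit
    std::printf("DropoutTile: %d (was %d)\n", dropout_tile, old_dropout_tile);
    return 0;
}
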