"git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "80222c63fd9531048e801b2ae0e43a2bd08f6d6b"
Commit 93235bb4 authored by Anthony Chang

fully implemented gemm + bias + activation + add + norm

parent 31b3f1dc
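What this commit wires up end to end is a fused epilogue on top of the GEMM: the accumulator goes through the activation, a 1xN bias and an MxN residual ("add") tensor are added, and the row is layer-normalized along N with gamma/beta; a final output element-wise op (c_element_op) follows and is omitted below for brevity. The sketch is a minimal scalar reference of that pipeline, assuming row-major float buffers and using ReLU as a stand-in for the element-wise operators; the function and variable names are illustrative, not the kernel's API.

```cpp
// Minimal scalar reference of GEMM + bias + activation + residual add + layernorm.
// Illustrative only: plain row-major float buffers, ReLU as a stand-in activation.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void gemm_bias_act_add_layernorm(const std::vector<float>& a,     // M x K
                                 const std::vector<float>& b,     // K x N
                                 const std::vector<float>& bias,  // 1 x N
                                 const std::vector<float>& add,   // M x N (residual)
                                 const std::vector<float>& gamma, // 1 x N
                                 const std::vector<float>& beta,  // 1 x N
                                 std::vector<float>& c,           // M x N (output)
                                 std::size_t M, std::size_t N, std::size_t K,
                                 float epsilon = 1e-5f)
{
    std::vector<float> acc(M * N);
    for(std::size_t m = 0; m < M; ++m)
    {
        // GEMM, activation, bias and residual add for one output row
        for(std::size_t n = 0; n < N; ++n)
        {
            float sum = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                sum += a[m * K + k] * b[k * N + n];
            sum = std::max(sum, 0.f); // activation (stand-in for acc_element_op)
            acc[m * N + n] = sum + bias[n] + add[m * N + n];
        }

        // layernorm over the N dimension of this row
        float mean = 0.f, var = 0.f;
        for(std::size_t n = 0; n < N; ++n)
            mean += acc[m * N + n];
        mean /= static_cast<float>(N);
        for(std::size_t n = 0; n < N; ++n)
        {
            const float d = acc[m * N + n] - mean;
            var += d * d;
        }
        var /= static_cast<float>(N);
        const float inv_std = 1.f / std::sqrt(var + epsilon);
        for(std::size_t n = 0; n < N; ++n)
            c[m * N + n] = gamma[n] * (acc[m * N + n] - mean) * inv_std + beta[n];
    }
}
```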
@@ -142,6 +142,7 @@ int main(int argc, char* argv[])
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<AccDataType> acc_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<C0DataType> c0_n_bias(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
+   Tensor<C0DataType> c0_m_n_add(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<C0DataType> c0_n_gamma(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
    Tensor<C0DataType> c0_n_beta(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));

@@ -149,6 +150,7 @@ int main(int argc, char* argv[])
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
    std::cout << "c0_n_bias: " << c0_n_bias.mDesc << std::endl;
+   std::cout << "c0_m_n_add: " << c0_m_n_add.mDesc << std::endl;
    std::cout << "c0_n_gamma: " << c0_n_gamma.mDesc << std::endl;
    std::cout << "c0_n_beta: " << c0_n_beta.mDesc << std::endl;

@@ -169,6 +171,7 @@ int main(int argc, char* argv[])
    }
    c0_n_bias.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5});
+   c0_m_n_add.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5});
    c0_n_gamma.GenerateTensorValue(GeneratorTensor_2<C0DataType>{0, 2});
    c0_n_beta.GenerateTensorValue(GeneratorTensor_2<C0DataType>{0, 5});
    c_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<CDataType>{0});

@@ -178,12 +181,14 @@ int main(int argc, char* argv[])
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
    DeviceMem c0_bias_buf(sizeof(C0DataType) * c0_n_bias.mDesc.GetElementSpace());
+   DeviceMem c0_add_buf(sizeof(C0DataType) * c0_m_n_add.mDesc.GetElementSpace());
    DeviceMem c0_gamma_buf(sizeof(C0DataType) * c0_n_gamma.mDesc.GetElementSpace());
    DeviceMem c0_beta_buf(sizeof(C0DataType) * c0_n_beta.mDesc.GetElementSpace());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    c0_bias_buf.ToDevice(c0_n_bias.mData.data());
+   c0_add_buf.ToDevice(c0_m_n_add.mData.data());
    c0_gamma_buf.ToDevice(c0_n_gamma.mData.data());
    c0_beta_buf.ToDevice(c0_n_beta.mData.data());

@@ -198,6 +203,7 @@ int main(int argc, char* argv[])
    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                      static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                      static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
+                                     static_cast<C0DataType*>(c0_add_buf.GetDeviceBuffer()),
                                      static_cast<C0DataType*>(c0_bias_buf.GetDeviceBuffer()),
                                      static_cast<C0DataType*>(c0_gamma_buf.GetDeviceBuffer()),
                                      static_cast<C0DataType*>(c0_beta_buf.GetDeviceBuffer()),

@@ -244,10 +250,11 @@ int main(int argc, char* argv[])
    auto ref_argument = ref_gemm.MakeArgument(a_m_k,
                                              b_k_n,
-                                             c_m_n_host_result,
                                              c0_n_bias,
+                                             c0_m_n_add,
                                              c0_n_gamma,
                                              c0_n_beta,
+                                             c_m_n_host_result,
                                              a_element_op,
                                              b_element_op,
                                              acc_element_op,
...
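With both the device result and the host reference produced by the hunks above, the example's remaining step is to compare c_m_n_device_result against c_m_n_host_result. A plain mismatch count along the lines below is enough for a sanity check; the names and tolerances are illustrative, and in the actual example the project's own comparison helper would be the natural choice.

```cpp
// Minimal host-side verification sketch: compare device output against the
// host reference with an absolute-plus-relative tolerance (illustrative only).
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

bool verify(const std::vector<float>& device_result,
            const std::vector<float>& host_result,
            float atol = 1e-3f,
            float rtol = 1e-3f)
{
    std::size_t mismatches = 0;
    for(std::size_t i = 0; i < device_result.size(); ++i)
    {
        const float diff = std::fabs(device_result[i] - host_result[i]);
        const float tol  = atol + rtol * std::fabs(host_result[i]);
        if(diff > tol)
            ++mismatches;
    }
    if(mismatches != 0)
        std::cout << mismatches << " of " << device_result.size()
                  << " values exceed tolerance\n";
    return mismatches == 0;
}
```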
@@ -431,9 +431,10 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 CDataType* p_c_grid,
-                const C0DataType* p_c0_bias,
-                const C0DataType* p_c0_gamma,
-                const C0DataType* p_c0_beta,
+                const C0DataType* p_c0_grid_add,
+                const C0DataType* p_c0_grid_bias,
+                const C0DataType* p_c0_grid_gamma,
+                const C0DataType* p_c0_grid_beta,
                 index_t MRaw,
                 index_t NRaw,
                 index_t KRaw,

@@ -447,9 +448,10 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
-             p_c0_bias_{p_c0_bias},
-             p_c0_gamma_{p_c0_gamma},
-             p_c0_beta_{p_c0_beta},
+             p_c0_grid_bias_{p_c0_grid_bias},
+             p_c0_grid_add_{p_c0_grid_add},
+             p_c0_grid_gamma_{p_c0_grid_gamma},
+             p_c0_grid_beta_{p_c0_grid_beta},
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
              c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},

@@ -480,9 +482,10 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;
-       const C0DataType* p_c0_bias_;
-       const C0DataType* p_c0_gamma_;
-       const C0DataType* p_c0_beta_;
+       const C0DataType* p_c0_grid_bias_;
+       const C0DataType* p_c0_grid_add_;
+       const C0DataType* p_c0_grid_gamma_;
+       const C0DataType* p_c0_grid_beta_;
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        CGridDesc_M_N c_grid_desc_m_n_;

@@ -564,9 +567,10 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
                arg.p_a_grid_,
                arg.p_b_grid_,
                arg.p_c_grid_,
-               arg.p_c0_bias_,
-               arg.p_c0_gamma_,
-               arg.p_c0_beta_,
+               arg.p_c0_grid_bias_,
+               arg.p_c0_grid_add_,
+               arg.p_c0_grid_gamma_,
+               arg.p_c0_grid_beta_,
                arg.a_element_op_,
                arg.b_element_op_,
                arg.acc_element_op_,

@@ -603,9 +607,10 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
                arg.p_a_grid_,
                arg.p_b_grid_,
                arg.p_c_grid_,
-               arg.p_c0_bias_,
-               arg.p_c0_gamma_,
-               arg.p_c0_beta_,
+               arg.p_c0_grid_bias_,
+               arg.p_c0_grid_add_,
+               arg.p_c0_grid_gamma_,
+               arg.p_c0_grid_beta_,
                arg.a_element_op_,
                arg.b_element_op_,
                arg.acc_element_op_,

@@ -656,9 +661,10 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
-                            const CDataType* p_c0_bias,
-                            const CDataType* p_c0_gamma,
-                            const CDataType* p_c0_beta,
+                            const C0DataType* p_c0_bias,
+                            const C0DataType* p_c0_add,
+                            const C0DataType* p_c0_gamma,
+                            const C0DataType* p_c0_beta,
                             index_t MRaw,
                             index_t NRaw,
                             index_t KRaw,

@@ -674,6 +680,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
                        p_b,
                        p_c,
                        p_c0_bias,
+                       p_c0_add,
                        p_c0_gamma,
                        p_c0_beta,
                        MRaw,

@@ -694,6 +701,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
                                           const void* p_b,
                                           void* p_c,
                                           const void* p_c0_bias,
+                                          const void* p_c0_add,
                                           const void* p_c0_gamma,
                                           const void* p_c0_beta,
                                           index_t MRaw,

@@ -712,6 +720,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          static_cast<const C0DataType*>(p_c0_bias),
+                                         static_cast<const C0DataType*>(p_c0_add),
                                          static_cast<const C0DataType*>(p_c0_gamma),
                                          static_cast<const C0DataType*>(p_c0_beta),
                                          MRaw,
...
@@ -39,6 +39,7 @@ __global__ void
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,               // MxN
        const FloatC0* __restrict__ p_c0_bias_grid,  // 1xN
+       const FloatC0* __restrict__ p_c0_add_grid,   // MxN
        const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN
        const FloatC0* __restrict__ p_c0_beta_grid,  // 1xN
        const AElementwiseOperation a_element_op,

@@ -60,6 +61,7 @@ __global__ void
        p_b_grid,
        p_c_grid,
        p_c0_bias_grid,
+       p_c0_add_grid,
        p_c0_gamma_grid,
        p_c0_beta_grid,
        p_shared,

@@ -79,6 +81,7 @@ __global__ void
    ignore = p_b_grid;
    ignore = p_c_grid;
    ignore = p_c0_bias_grid;
+   ignore = p_c0_add_grid;
    ignore = p_c0_gamma_grid;
    ignore = p_c0_beta_grid;
    ignore = a_element_op;

@@ -350,6 +353,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        const FloatC0* __restrict__ p_c0_bias_grid,  // 1xN
+       const FloatC0* __restrict__ p_c0_add_grid,   // MxN
        const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN
        const FloatC0* __restrict__ p_c0_beta_grid,  // 1xN
        void* __restrict__ p_shared,

@@ -372,6 +376,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        auto c0_bias_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c0_bias_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize());
+       // Note: c0_add is of same layout as c so we don't declare new c0_add_desc here
+       auto c0_add_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_c0_add_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        auto c0_gamma_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c0_gamma_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize());
        auto c0_beta_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(

@@ -814,10 +821,28 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                1,
                true>(
                c0_grid_desc_mblock_mperblock_nblock_nperblock,
-               make_multi_index(I0,
-                                m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
-                                I0,
-                                n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
+               make_multi_index(block_work_idx[I0],
+                                c_reduce_thread_data_idx_begin[I0],
+                                block_work_idx[I1],
+                                c_reduce_thread_data_idx_begin[I1]));
+
+       // Note: c0_add is of same layout as c so we don't declare new c0_add_desc here
+       auto c0_add_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
+           FloatC0,
+           FloatC0,
+           decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
+           decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
+           Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
+           Sequence<0, 1, 2, 3>,
+           3,
+           CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
+           1,
+           true>(
+           c_grid_desc_mblock_mperblock_nblock_nperblock,
+           make_multi_index(block_work_idx[I0],
+                            c_reduce_thread_data_idx_begin[I0],
+                            block_work_idx[I1],
+                            c_reduce_thread_data_idx_begin[I1]));

        // space filling curve for threadwise C in VGPR
        constexpr auto sfc_c_vgpr =

@@ -880,6 +905,19 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                        static_cast<FloatReduceAcc>(c0_thread_buf(i)); // bias
                });

+           c0_add_thread_copy_global_to_vgpr.Run(
+               c_grid_desc_mblock_mperblock_nblock_nperblock,
+               c0_add_grid_buf,
+               c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
+               make_tuple(I0, I0, I0, I0),
+               c0_thread_buf);
+
+           static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
+               [&](auto i) {
+                   c_reduce_thread_buf(i) +=
+                       static_cast<FloatReduceAcc>(c0_thread_buf(i)); // add
+               });
+
            using ThreadwiseReduceD0 =
                ThreadwiseReduction<FloatReduceAcc,
                                    decltype(c_reduce_thread_desc_mperblock_nperblock),

@@ -1010,6 +1048,10 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                // move on C0
                c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
                    c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
+
+               // move on C0_add
+               c0_add_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
+                   c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
            }
        });
    }
...
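One design point in the gridwise changes above: the residual "add" tensor has exactly the same MxN layout as C, so the kernel reuses c_grid_desc_mblock_mperblock_nblock_nperblock and an identically configured threadwise copy for it instead of introducing a separate c0_add descriptor (hence the repeated "same layout as c" note). The toy sketch below illustrates that shared-layout idea with a hypothetical 2-D descriptor; it is not the ThreadwiseTensorSliceTransfer API.

```cpp
// Toy illustration of reusing one layout descriptor for two same-shaped tensors.
// Hypothetical names; CK's real descriptors and transfers are far more general.
#include <cstddef>
#include <vector>

struct Desc2D
{
    std::size_t rows, cols, row_stride; // one layout description
    std::size_t offset(std::size_t r, std::size_t c) const { return r * row_stride + c; }
};

// Copies a tile starting at (r0, c0) from src into dst_tile using the given
// descriptor; the same Desc2D can serve both the C read and the c0_add read.
void copy_tile(const Desc2D& desc,
               const std::vector<float>& src,
               std::vector<float>& dst_tile,
               std::size_t r0, std::size_t c0,
               std::size_t tile_rows, std::size_t tile_cols)
{
    for(std::size_t r = 0; r < tile_rows; ++r)
        for(std::size_t c = 0; c < tile_cols; ++c)
            dst_tile[r * tile_cols + c] = src[desc.offset(r0 + r, c0 + c)];
}
```

Because both tensors share one descriptor, only the source pointer differs between the two reads, which is why the commit only adds a second copy object bound to the same descriptor and advances its window with the same c_global_step.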
@@ -33,6 +33,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
    static void RunLayernorm(Tensor<OutDataType>& result,
                             const Tensor<ComputeDataType>& acc, // MxN
                             const Tensor<InDataType>& bias,     // 1xN
+                            const Tensor<InDataType>& add,      // MxN
                             const Tensor<InDataType>& gamma,    // 1xN
                             const Tensor<InDataType>& beta,     // 1xN
                             const InDataType epsilon = 1e-5)

@@ -53,6 +54,11 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
            self(idx[0], idx[1]) = acc(idx[0], idx[1]) + bias(idx[1]);
        });

+       // add from other layer
+       acc_layernorm.ForEach([&](auto& self, auto idx) {
+           self(idx[0], idx[1]) += add(idx[0], idx[1]);
+       });
+
        // reduce N dim
        for(size_t i = 0; i < M; i++)
        {

@@ -88,10 +94,11 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
    {
        Argument(const Tensor<ADataType>& a_m_k,
                 const Tensor<BDataType>& b_k_n,
-                Tensor<CDataType>& c_m_n,
                 const Tensor<C0DataType>& c0_n_bias,  // 1xN
+                const Tensor<C0DataType>& c0_m_n_add, // MxN
                 const Tensor<C0DataType>& c0_n_gamma, // 1xN
                 const Tensor<C0DataType>& c0_n_beta,  // 1xN
+                Tensor<CDataType>& c_m_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 AccElementwiseOperation acc_element_op,

@@ -99,10 +106,11 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
                 const CDataType epsilon = 1e-5)
            : a_m_k_{a_m_k},
              b_k_n_{b_k_n},
-             c_m_n_{c_m_n},
              c0_n_bias_{c0_n_bias},
+             c0_m_n_add_{c0_m_n_add},
              c0_n_gamma_{c0_n_gamma},
              c0_n_beta_{c0_n_beta},
+             c_m_n_{c_m_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              acc_element_op_{acc_element_op},

@@ -113,10 +121,11 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
        const Tensor<ADataType>& a_m_k_;
        const Tensor<BDataType>& b_k_n_;
-       Tensor<CDataType>& c_m_n_;
        const Tensor<C0DataType>& c0_n_bias_;
+       const Tensor<C0DataType>& c0_m_n_add_;
        const Tensor<C0DataType>& c0_n_gamma_;
        const Tensor<C0DataType>& c0_n_beta_;
+       Tensor<CDataType>& c_m_n_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;

@@ -145,10 +154,13 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
                                                  arg.b_element_op_,
                                                  arg.acc_element_op_);

+           // gemm
            ref_invoker.Run(ref_argument);

-           RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_bias_, arg.c0_n_gamma_, arg.c0_n_beta_);
+           // layernorm
+           RunLayernorm(
+               arg.c_m_n_, acc_m_n, arg.c0_n_bias_, arg.c0_m_n_add_, arg.c0_n_gamma_, arg.c0_n_beta_);

+           // elementwise op
            arg.c_m_n_.ForEach([&](auto& self, auto idx) {
                arg.c_element_op_(self(idx[0], idx[1]), self(idx[0], idx[1]));
            });

@@ -173,10 +185,11 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
                             const Tensor<BDataType>& b_k_n,
-                            Tensor<CDataType>& c_m_n,
                             const Tensor<C0DataType>& c0_n_bias,  // 1xN
+                            const Tensor<C0DataType>& c0_m_n_add, // MxN
                             const Tensor<C0DataType>& c0_n_gamma, // 1xN
                             const Tensor<C0DataType>& c0_n_beta,  // 1xN
+                            Tensor<CDataType>& c_m_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             AccElementwiseOperation acc_element_op,

@@ -185,10 +198,11 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
    {
        return Argument{a_m_k,
                        b_k_n,
-                       c_m_n,
                        c0_n_bias,
+                       c0_m_n_add,
                        c0_n_gamma,
                        c0_n_beta,
+                       c_m_n,
                        a_element_op,
                        b_element_op,
                        acc_element_op,
...