Commit 8e7b98eb authored by letaoqin's avatar letaoqin
Browse files

group gemm add kernel

parent b158c537
......@@ -76,7 +76,7 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecia
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true;
static constexpr bool Deterministic = false;
#if(DIM <= 32)
using DeviceGemmInstance =
......
......@@ -57,7 +57,7 @@ int run(int argc, char* argv[])
std::vector<const void*> p_b0;
std::vector<const void*> p_b1;
std::vector<void*> p_c;
std::vector<const void*> p_d;
std::vector<std::vector<const void*>> p_d;
std::vector<void*> p_z; // for result verification
std::vector<void*> p_z_nullptr; // for time test
std::vector<void*> p_lse;
......@@ -122,8 +122,8 @@ int run(int argc, char* argv[])
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
......@@ -159,7 +159,7 @@ int run(int argc, char* argv[])
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<ZDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
......@@ -176,6 +176,7 @@ int run(int argc, char* argv[])
<< "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
<< "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
<< "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
<< "d_gs_ms_ns[" << i << "]: " << d_gs_ms_ns.mDesc << ", "
<< "z_gs_ms_ns[" << i << "]: " << z_gs_ms_ns.mDesc << ", "
<< "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc
<< std::endl;
......@@ -190,7 +191,7 @@ int run(int argc, char* argv[])
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1, 1});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
......@@ -243,7 +244,9 @@ int run(int argc, char* argv[])
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
p_d.push_back(d_tensors_device[i]->GetDeviceBuffer());
p_d.push_back({d_tensors_device[i]->GetDeviceBuffer()});
// std::cout << "from host group id: " << i << " d address: " <<
// d_tensors_device[i]->GetDeviceBuffer() << std::endl;
p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
p_z_nullptr.push_back(nullptr);
p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
......@@ -266,8 +269,8 @@ int run(int argc, char* argv[])
p_c,
p_z_nullptr,
p_lse,
std::vector<std::vector<const void*>>{p_d}, // p_acc0_biases
{}, // p_acc1_biases
p_d, // p_acc0_biases
{}, // p_acc1_biases
problem_descs,
a_element_op,
b0_element_op,
......@@ -309,8 +312,8 @@ int run(int argc, char* argv[])
p_c,
p_z,
p_lse,
{}, // p_acc0_biases
{}, // p_acc1_biases
p_d, // p_acc0_biases
{}, // p_acc1_biases
problem_descs,
a_element_op,
b0_element_op,
......@@ -344,6 +347,7 @@ int run(int argc, char* argv[])
const auto& a_gs_ms_ks = a_tensors[i];
const auto& b0_gs_ns_ks = b0_tensors[i];
const auto& b1_gs_os_ns = b1_tensors[i];
const auto& d_gs_ms_ns = d_tensors[i];
auto& c_gs_ms_os_device_result = c_tensors[i];
auto& z_gs_ms_ns_device_result = z_tensors[i];
auto& lse_gs_ms_device_result = lse_tensors[i];
......@@ -358,7 +362,8 @@ int run(int argc, char* argv[])
Tensor<ADataType> a_g_m_k({G0 * G1, M, K});
Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N});
Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0
Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0
Tensor<AccDataType> d_g_m_n({G0 * G1, M, N});
Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
......@@ -378,6 +383,10 @@ int run(int argc, char* argv[])
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
......@@ -390,6 +399,8 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += d_g_m_n(idx); });
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
......@@ -470,6 +481,10 @@ int run(int argc, char* argv[])
"Error: Incorrect results lse!",
rtol,
atol);
if(!pass_)
{
std::cout << "from group: " << i << std::endl;
}
pass &= pass_;
}
if(pass)
......
......@@ -14,7 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2r2.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -39,7 +39,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2r2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
......@@ -95,6 +95,8 @@ __global__ void
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetDBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
......@@ -109,6 +111,9 @@ __global__ void
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_d_grid_ == nullptr
? nullptr
: arg_ptr[group_id].p_d_grid_ + d_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr
? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset,
......@@ -126,6 +131,7 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
......@@ -146,6 +152,8 @@ __global__ void
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_d_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_d_grid_ + d_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset,
arg_ptr[group_id].p_lse_grid_ == nullptr
......@@ -162,6 +170,7 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
......@@ -330,6 +339,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I9 = Number<9>{};
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
......@@ -495,7 +511,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2<
ADataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType,
......@@ -647,20 +663,17 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
b1_element_op_{b1_element_op},
c_element_op_{c_element_op}
{
ignore = p_acc1_biases_vec;
// TODO ANT: implement bias addition
group_count_ = problem_desc_vec.size();
if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size()))
group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size() &&
(group_count_ == p_acc0_biases_vec.size() || p_acc0_biases_vec.size() == 0)))
{
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
}
if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
{
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
}
grid_size_ = 0;
index_t z_random_matrix_offset = 0;
......@@ -884,18 +897,18 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
auto launch_kernel =
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
GemmAccDataType,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_,
is_lse_storing_,
Deterministic>;
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2r2<GridwiseGemm,
GemmAccDataType,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_,
is_lse_storing_,
Deterministic>;
return launch_and_time_kernel(
stream_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment