Commit de6a70f7 authored by Jing Zhang

add ds

parent 1d11426a
@@ -33,9 +33,9 @@ using CShuffleDataType = F16;
 using DsDataType = ck::Tuple<>;
 using EDataType  = F16;

 using ALayout = Row;
 using BLayout = Col;
 using ELayout = Row;

 using AElementOp = PassThrough;
 using BElementOp = PassThrough;
@@ -63,9 +63,9 @@ int main(int argc, char* argv[])
     int init_method  = 1;
     bool time_kernel = false;

-    const int M = 256;
-    const int N = 128;
-    const int K = 64;
+    const int M = 256 * (rand() % 16 + 1);
+    const int N = 128 * (rand() % 16 + 1);
+    const int K = 64 * (rand() % 16 + 1);

     const int stride_A = K;
     const int stride_B = K;
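A side note on this change: `rand()` appears unseeded in the shown diff, so every run would exercise the same pseudo-random sizes, and the multipliers keep M, N, K at multiples of 256, 128, and 64. A minimal standalone sketch (nothing here is library code):

```cpp
#include <cstdio>
#include <cstdlib>
#include <ctime>

int main()
{
    // std::srand(static_cast<unsigned>(std::time(nullptr))); // would vary sizes per run
    // rand() % 16 + 1 picks a multiplier in [1, 16], so the problem sizes stay
    // aligned to 256/128/64 and remain friendly to the kernel's tiling.
    const int M = 256 * (std::rand() % 16 + 1);
    const int N = 128 * (std::rand() % 16 + 1);
    const int K = 64 * (std::rand() % 16 + 1);
    std::printf("M=%d N=%d K=%d\n", M, N, K);
    return 0;
}
```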
@@ -112,12 +112,12 @@ int main(int argc, char* argv[])
     Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
     Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));
-    Tensor<EDataType> c_g_m_n_device_result(
+    Tensor<EDataType> e_g_m_n_device_result(
         f_host_tensor_descriptor(batch_count, M, N, stride_C, ELayout{}));

     std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
     std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
-    std::cout << "c_g_m_n: " << c_g_m_n_device_result.mDesc << std::endl;
+    std::cout << "e_g_m_n: " << e_g_m_n_device_result.mDesc << std::endl;

     switch(init_method)
     {
@@ -134,35 +134,38 @@ int main(int argc, char* argv[])
     DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
     DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
-    DeviceMem c_device_buf(sizeof(EDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
+    DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpace());

     a_device_buf.ToDevice(a_g_m_k.mData.data());
     b_device_buf.ToDevice(b_g_k_n.mData.data());

     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
-    auto c_element_op = CDEElementOp{};
+    auto cde_element_op = CDEElementOp{};

     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();

     // do GEMM
-    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
-                                      static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-                                      static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
+    auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
+                                      b_device_buf.GetDeviceBuffer(),
+                                      {},
+                                      c_device_buf.GetDeviceBuffer(),
                                       M,
                                       N,
                                       K,
                                       stride_A,
                                       stride_B,
+                                      {},
                                       stride_C,
                                       batch_stride_A,
                                       batch_stride_B,
+                                      {},
                                       batch_stride_C,
                                       batch_count,
                                       a_element_op,
                                       b_element_op,
-                                      c_element_op);
+                                      cde_element_op);

     if(!gemm.IsSupportedArgument(argument))
     {
@@ -189,32 +192,21 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-        c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
+        c_device_buf.FromDevice(e_g_m_n_device_result.mData.data());

         auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
         auto ref_invoker      = ref_batched_gemm.MakeInvoker();

-        Tensor<EDataType> c_g_m_n_host_result = HostTensorDescriptor(
-            std::vector<std::size_t>({batch_count, M, N}), std::vector<std::size_t>({M * N, N, 1}));
+        Tensor<EDataType> e_g_m_n_host_result(
+            f_host_tensor_descriptor(batch_count, M, N, stride_C, ELayout{}));

         auto ref_argument = ref_batched_gemm.MakeArgument(
-            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
+            a_g_m_k, b_g_k_n, e_g_m_n_host_result, a_element_op, b_element_op, cde_element_op);

         ref_invoker.Run(ref_argument);

-        //for(int b = 0; b < batch_count; b++)
-        //{
-        //    for(int m = 0; m < M; m++)
-        //    {
-        //        for(int n = 0; n < N; n++)
-        //        {
-        //            c_g_m_n_host_result(b, m, n) = c_g_m_n_host_result(b, m, n);
-        //        }
-        //    }
-        //}
-
         pass = ck::utils::check_err(
-            c_g_m_n_host_result.mData, c_g_m_n_device_result.mData, "Error: Incorrect results c");
+            e_g_m_n_host_result.mData, e_g_m_n_device_result.mData, "Error: Incorrect results c");
     }

     return pass ? 0 : 1;
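With `DsDataType = ck::Tuple<>` as above, the three `{}` arguments are empty D arrays. For a non-empty case such as a single bias tensor (say `DsDataType = ck::Tuple<F16>`), the call would plausibly look like the sketch below; `d0_device_buf`, `stride_D0`, and `batch_stride_D0` are hypothetical names, the other variables are reused from the example, and this is not code from the commit:

```cpp
// Hypothetical usage sketch for NumDTensor == 1, reusing the example's variables.
auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                  b_device_buf.GetDeviceBuffer(),
                                  {d0_device_buf.GetDeviceBuffer()}, // p_ds (hypothetical)
                                  c_device_buf.GetDeviceBuffer(),
                                  M,
                                  N,
                                  K,
                                  stride_A,
                                  stride_B,
                                  {stride_D0},       // StrideDs (hypothetical)
                                  stride_C,
                                  batch_stride_A,
                                  batch_stride_B,
                                  {batch_stride_D0}, // BatchStrideDs (hypothetical)
                                  batch_stride_C,
                                  batch_count,
                                  a_element_op,
                                  b_element_op,
                                  cde_element_op);
```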
...
@@ -29,16 +29,18 @@ struct DeviceBatchedGemmMultiD : public BaseOperator
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
+                        std::array<const void*, NumDTensor> p_ds,
                         void* p_c,
                         ck::index_t M,
                         ck::index_t N,
                         ck::index_t K,
                         ck::index_t StrideA,
                         ck::index_t StrideB,
-                        //std::array<ck::index_t, NumDTensor> StrideDs,
+                        std::array<ck::index_t, NumDTensor> StrideDs,
                         ck::index_t StrideE,
                         ck::index_t BatchStrideA,
                         ck::index_t BatchStrideB,
+                        std::array<ck::index_t, NumDTensor> BatchStrideDs,
                         ck::index_t BatchStrideE,
                         ck::index_t Batch,
                         AElementwiseOperation a_element_op,
@@ -58,16 +60,17 @@ template <typename ALayout,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation>
-using DeviceBatchedGemmMultiDPtr = std::unique_ptr<DeviceBatchedGemmMultiD<ALayout,
-                                                                           BLayout,
-                                                                           CLayout,
-                                                                           ADataType,
-                                                                           BDataType,
-                                                                           DsDataType,
-                                                                           EDataType,
-                                                                           AElementwiseOperation,
-                                                                           BElementwiseOperation,
-                                                                           CDEElementwiseOperation>>;
+using DeviceBatchedGemmMultiDPtr =
+    std::unique_ptr<DeviceBatchedGemmMultiD<ALayout,
+                                            BLayout,
+                                            CLayout,
+                                            ADataType,
+                                            BDataType,
+                                            DsDataType,
+                                            EDataType,
+                                            AElementwiseOperation,
+                                            BElementwiseOperation,
+                                            CDEElementwiseOperation>>;

 } // namespace device
 } // namespace tensor_operation
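A usage sketch for the updated polymorphic entry point, assuming `NumDTensor == 0` so all the new D arrays are empty. Here `op_ptr` and the scalar arguments are placeholders standing in for a concrete instance and problem, not code from this commit:

```cpp
// Hypothetical call through a DeviceBatchedGemmMultiDPtr with no D tensors.
std::array<const void*, 0> p_ds          = {};
std::array<ck::index_t, 0> StrideDs      = {};
std::array<ck::index_t, 0> BatchStrideDs = {};

auto arg_ptr = op_ptr->MakeArgumentPointer(p_a,
                                           p_b,
                                           p_ds,
                                           p_c,
                                           M, N, K,
                                           StrideA, StrideB, StrideDs,
                                           StrideE,
                                           BatchStrideA, BatchStrideB, BatchStrideDs,
                                           BatchStrideE,
                                           Batch,
                                           a_element_op, b_element_op, cde_element_op);
```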
...
@@ -47,10 +47,12 @@ namespace device {
 */
 template <typename GridwiseGemm,
           typename FloatAB,
+          typename FloatDsPointer,
           typename FloatC,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
-          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
@@ -63,18 +65,22 @@ __global__ void
 #endif
     kernel_batched_gemm_xdl(const FloatAB* __restrict__ p_a_grid,
                             const FloatAB* __restrict__ p_b_grid,
+                            FloatDsPointer p_ds_grid,
                             FloatC* __restrict__ p_e_grid,
                             const index_t batch_count,
                             const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
                             const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
-                            const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+                            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                                ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                            const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                                 e_grid_desc_mblock_mperblock_nblock_nperblock_,
                             const AElementwiseOperation a_element_op,
                             const BElementwiseOperation b_element_op,
-                            const CDEElementwiseOperation c_element_op,
+                            const CDEElementwiseOperation cde_element_op,
                             const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
                             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -84,38 +90,47 @@ __global__ void
         static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
     const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+
+    const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainKBlockLoop>(
-        p_a_grid + a_batch_offset,
-        p_b_grid + b_batch_offset,
-        ck::Tuple<>{},
-        p_e_grid + c_batch_offset,
-        p_shared,
-        a_element_op,
-        b_element_op,
-        c_element_op,
-        a_grid_desc_k0_m_k1,
-        b_grid_desc_k0_n_k1,
-        ck::StaticallyIndexedArray<
-            typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-            0>{},
-        e_grid_desc_mblock_mperblock_nblock_nperblock_,
-        block_2_ctile_map);
+    FloatDsPointer p_ds_grid_grp;
+
+    static constexpr index_t NumDTensor =
+        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
+
+    static_for<0, NumDTensor, 1>{}(
+        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
+
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+                                                  p_b_grid + b_batch_offset,
+                                                  p_ds_grid_grp,
+                                                  p_e_grid + e_batch_offset,
+                                                  p_shared,
+                                                  a_element_op,
+                                                  b_element_op,
+                                                  cde_element_op,
+                                                  a_grid_desc_k0_m_k1,
+                                                  b_grid_desc_k0_n_k1,
+                                                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  e_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                                  block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
+    ignore = p_ds_grid;
     ignore = p_e_grid;
     ignore = batch_count;
     ignore = a_grid_desc_k0_m_k1;
     ignore = b_grid_desc_k0_n_k1;
+    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
     ignore = a_element_op;
     ignore = b_element_op;
-    ignore = c_element_op;
+    ignore = cde_element_op;
     ignore = compute_ptr_offset_of_batch;
     ignore = block_2_ctile_map;
 #endif
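The batching logic above reduces to plain index arithmetic: the workgroup's batch index `g_idx` scales each tensor's batch stride, with one offset per D tensor. A self-contained model in ordinary C++ (all names and values illustrative, no CK types):

```cpp
#include <array>
#include <cstdio>

int main()
{
    constexpr int NumDTensor = 2;
    using long_index_t = long long;

    const long_index_t batch_stride_a = 1024;
    const long_index_t batch_stride_b = 512;
    const long_index_t batch_stride_e = 2048;
    // A per-D batch stride of 0 makes every batch read the same D data.
    const std::array<long_index_t, NumDTensor> batch_stride_ds = {2048, 0};

    const int g_idx = 3; // batch index of this (hypothetical) workgroup

    // Mirrors GetAPtrOffset / GetBPtrOffset / GetEPtrOffset:
    const long_index_t a_off = g_idx * batch_stride_a;
    const long_index_t b_off = g_idx * batch_stride_b;
    const long_index_t e_off = g_idx * batch_stride_e;

    // Mirrors GetDsPtrOffset: one offset per D tensor.
    std::array<long_index_t, NumDTensor> ds_off{};
    for(int i = 0; i < NumDTensor; ++i)
        ds_off[i] = g_idx * batch_stride_ds[i];

    std::printf("a=%lld b=%lld e=%lld d0=%lld d1=%lld\n",
                a_off, b_off, e_off, ds_off[0], ds_off[1]);
    return 0;
}
```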
@@ -456,8 +471,12 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
     {
         ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
                                        index_t BatchStrideB,
+                                       std::array<ck::index_t, NumDTensor> BatchStrideDs,
                                        index_t BatchStrideE)
-            : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideE_(BatchStrideE)
+            : BatchStrideA_(BatchStrideA),
+              BatchStrideB_(BatchStrideB),
+              BatchStrideDs_(BatchStrideDs),
+              BatchStrideE_(BatchStrideE)
         {
         }
@@ -471,7 +490,16 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
             return g_idx * static_cast<long_index_t>(BatchStrideB_);
         }

-        __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
+        __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
+        {
+            std::array<long_index_t, NumDTensor> ds_offset;
+            static_for<0, NumDTensor, 1>{}([&](auto i) {
+                ds_offset[i] = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]);
+            });
+            return ds_offset;
+        }
+
+        __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
         {
             return g_idx * static_cast<long_index_t>(BatchStrideE_);
         }
@@ -479,6 +507,7 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
         private:
         index_t BatchStrideA_;
         index_t BatchStrideB_;
+        std::array<ck::index_t, NumDTensor> BatchStrideDs_;
         index_t BatchStrideE_;
     };
@@ -535,41 +564,46 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
     // Argument
     struct Argument : public BaseArgument
     {
-        Argument(const ADataType* p_a_grid,
-                 const BDataType* p_b_grid,
-                 EDataType* p_e_grid,
+        Argument(const void* p_a_grid,
+                 const void* p_b_grid,
+                 std::array<const void*, NumDTensor> p_ds_grid,
+                 void* p_e_grid,
                  index_t M,
                  index_t N,
                  index_t K,
                  index_t StrideA,
                  index_t StrideB,
+                 std::array<ck::index_t, NumDTensor> StrideDs,
                  index_t StrideE,
                  index_t BatchStrideA,
                  index_t BatchStrideB,
+                 std::array<ck::index_t, NumDTensor> BatchStrideDs,
                  index_t BatchStrideE,
                  index_t Batch,
                  index_t M01,
                  index_t N01,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
-                 CDEElementwiseOperation c_element_op)
-            : p_a_grid_{p_a_grid},
-              p_b_grid_{p_b_grid},
-              p_e_grid_{p_e_grid},
+                 CDEElementwiseOperation cde_element_op)
+            : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
+              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
+              p_ds_grid_{}, // FIXME
+              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
              Batch_(Batch),
              a_grid_desc_ak0_m_ak1_{
                  DeviceBatchedGemmMultiDXdl::MakeAGridDescriptor_AK0_M_AK1(M, K, StrideA)},
              b_grid_desc_bk0_n_bk1_{
                  DeviceBatchedGemmMultiDXdl::MakeBGridDescriptor_BK0_N_BK1(K, N, StrideB)},
+              ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
              e_grid_desc_m_n_{DeviceBatchedGemmMultiDXdl::MakeEGridDescriptor_M_N(M, N, StrideE)},
              e_grid_desc_mblock_mperblock_nblock_nperblock_{},
-              compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideE},
+              compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
              M01_{M01},
              N01_{N01},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
-              cde_element_op_{c_element_op}
+              cde_element_op_{cde_element_op}
         {
             if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                            b_grid_desc_bk0_n_bk1_,
@@ -579,6 +613,19 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
                 e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                     GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         e_grid_desc_m_n_);
+
+                static_for<0, NumDTensor, 1>{}([&](auto i) {
+                    using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+
+                    p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
+
+                    const auto d_grid_desc_m_n =
+                        DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideDs[i]);
+
+                    ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
+                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                            d_grid_desc_m_n);
+                });
             }
         }
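One reading note: `p_ds_grid_{}` in the member-initializer list is only a placeholder (hence the `// FIXME` comment); the `static_for` in the body fills in the typed D pointers and per-D descriptors. Each D reuses `MakeEGridDescriptor_M_N` because a D tensor is logically M x N, distinguished only by its own `StrideDs[i]`. A toy model of that reuse, assuming the common convention that a stride of 0 broadcasts a single row (an assumption, not shown in this diff):

```cpp
#include <cstdio>

// Minimal stand-in for a row-major 2-D descriptor; not CK code.
struct Desc2D { int M, N, stride; };

Desc2D make_desc_m_n(int M, int N, int stride) { return {M, N, stride}; }

int main()
{
    const int M = 4, N = 8;
    const Desc2D e  = make_desc_m_n(M, N, /*StrideE*/ 8);
    const Desc2D d0 = make_desc_m_n(M, N, /*StrideDs[0]*/ 0); // e.g. a broadcast bias row
    std::printf("E: %dx%d stride %d, D0: %dx%d stride %d\n",
                e.M, e.N, e.stride, d0.M, d0.N, d0.stride);
    return 0;
}
```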
@@ -646,77 +693,57 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
             const auto K =
                 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

-            float ave_time = 0;
-
-            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
-            {
-                const auto kernel = kernel_batched_gemm_xdl<
-                    GridwiseGemm,
-                    ADataType, // TODO: distiguish A/B datatype
-                    EDataType,
-                    remove_reference_t<DeviceBatchedGemmMultiDXdl::AGridDesc_AK0_M_AK1>,
-                    remove_reference_t<DeviceBatchedGemmMultiDXdl::BGridDesc_BK0_N_BK1>,
-                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    AElementwiseOperation,
-                    BElementwiseOperation,
-                    CDEElementwiseOperation,
-                    ComputePtrOffsetOfStridedBatch,
-                    remove_reference_t<Block2CTileMap>,
-                    true>;
-
-                ave_time =
-                    launch_and_time_kernel(stream_config,
-                                           kernel,
-                                           dim3(grid_size),
-                                           dim3(BlockSize),
-                                           0,
-                                           arg.p_a_grid_,
-                                           arg.p_b_grid_,
-                                           arg.p_e_grid_,
-                                           arg.Batch_,
-                                           arg.a_grid_desc_ak0_m_ak1_,
-                                           arg.b_grid_desc_bk0_n_bk1_,
-                                           arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                           arg.a_element_op_,
-                                           arg.b_element_op_,
-                                           arg.cde_element_op_,
-                                           arg.compute_ptr_offset_of_batch_,
-                                           arg.block_2_ctile_map_);
-            }
-            else
-            {
-                const auto kernel = kernel_batched_gemm_xdl<
-                    GridwiseGemm,
-                    ADataType, // TODO: distiguish A/B datatype
-                    EDataType,
-                    remove_reference_t<DeviceBatchedGemmMultiDXdl::AGridDesc_AK0_M_AK1>,
-                    remove_reference_t<DeviceBatchedGemmMultiDXdl::BGridDesc_BK0_N_BK1>,
-                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    AElementwiseOperation,
-                    BElementwiseOperation,
-                    CDEElementwiseOperation,
-                    ComputePtrOffsetOfStridedBatch,
-                    remove_reference_t<Block2CTileMap>,
-                    false>;
-
-                ave_time =
-                    launch_and_time_kernel(stream_config,
-                                           kernel,
-                                           dim3(grid_size),
-                                           dim3(BlockSize),
-                                           0,
-                                           arg.p_a_grid_,
-                                           arg.p_b_grid_,
-                                           arg.p_e_grid_,
-                                           arg.Batch_,
-                                           arg.a_grid_desc_ak0_m_ak1_,
-                                           arg.b_grid_desc_bk0_n_bk1_,
-                                           arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                           arg.a_element_op_,
-                                           arg.b_element_op_,
-                                           arg.cde_element_op_,
-                                           arg.compute_ptr_offset_of_batch_,
-                                           arg.block_2_ctile_map_);
-            }
+            auto launch_kernel = [&](auto has_main_k_block_loop) {
+                constexpr bool has_main_loop = has_main_k_block_loop.value;
+
+                const auto kernel = kernel_batched_gemm_xdl<
+                    GridwiseGemm,
+                    ADataType, // TODO: distiguish A/B datatype
+                    typename GridwiseGemm::DsGridPointer,
+                    EDataType,
+                    DeviceOp::AGridDesc_AK0_M_AK1,
+                    DeviceOp::BGridDesc_BK0_N_BK1,
+                    ck::StaticallyIndexedArray<
+                        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        NumDTensor>,
+                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    CDEElementwiseOperation,
+                    ComputePtrOffsetOfStridedBatch,
+                    remove_reference_t<Block2CTileMap>,
+                    has_main_loop>;
+
+                return launch_and_time_kernel(stream_config,
+                                              kernel,
+                                              dim3(grid_size),
+                                              dim3(BlockSize),
+                                              0,
+                                              arg.p_a_grid_,
+                                              arg.p_b_grid_,
+                                              arg.p_ds_grid_,
+                                              arg.p_e_grid_,
+                                              arg.Batch_,
+                                              arg.a_grid_desc_ak0_m_ak1_,
+                                              arg.b_grid_desc_bk0_n_bk1_,
+                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                              arg.a_element_op_,
+                                              arg.b_element_op_,
+                                              arg.cde_element_op_,
+                                              arg.compute_ptr_offset_of_batch_,
+                                              arg.block_2_ctile_map_);
+            };
+
+            float ave_time = 0;
+
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            {
+                ave_time = launch_kernel(integral_constant<bool, true>{});
+            }
+            else
+            {
+                ave_time = launch_kernel(integral_constant<bool, false>{});
+            }

             return ave_time;
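The refactor above replaces two nearly identical branches with one `launch_kernel` lambda that receives the has-main-loop flag as an `integral_constant`, keeping it a compile-time template argument while sharing the launch code. A minimal standalone illustration of the pattern (not CK code):

```cpp
#include <cstdio>
#include <type_traits>

template <bool HasMainKBlockLoop>
void run_gemm() // stands in for the real kernel instantiation
{
    std::printf("has_main_k_block_loop = %d\n", HasMainKBlockLoop);
}

int main()
{
    auto launch_kernel = [&](auto has_main_k_block_loop) {
        // .value is a static constexpr member, so this stays a constant expression.
        constexpr bool has_main_loop = has_main_k_block_loop.value;
        run_gemm<has_main_loop>(); // template argument fixed at compile time
    };

    const bool k_loop = true; // runtime stand-in for CalculateHasMainKBlockLoop(K)
    if(k_loop)
        launch_kernel(std::integral_constant<bool, true>{});
    else
        launch_kernel(std::integral_constant<bool, false>{});
    return 0;
}
```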
@@ -750,81 +777,94 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
     }

-    static auto MakeArgument(const ADataType* p_a,
-                             const BDataType* p_b,
-                             EDataType* p_c,
+    static auto MakeArgument(const void* p_a,
+                             const void* p_b,
+                             std::array<const void*, NumDTensor> p_ds,
+                             void* p_c,
                              index_t M,
                              index_t N,
                              index_t K,
                              index_t StrideA,
                              index_t StrideB,
+                             std::array<index_t, NumDTensor> StrideDs,
                              index_t StrideE,
                              index_t BatchStrideA,
                              index_t BatchStrideB,
+                             std::array<ck::index_t, NumDTensor> BatchStrideDs,
                              index_t BatchStrideE,
                              index_t Batch,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
-                             CDEElementwiseOperation c_element_op)
+                             CDEElementwiseOperation cde_element_op)
     {
         return Argument{p_a,
                         p_b,
+                        p_ds,
                         p_c,
                         M,
                         N,
                         K,
                         StrideA,
                         StrideB,
+                        StrideDs,
                         StrideE,
                         BatchStrideA,
                         BatchStrideB,
+                        BatchStrideDs,
                         BatchStrideE,
                         Batch,
                         1,
                         1,
                         a_element_op,
                         b_element_op,
-                        c_element_op};
+                        cde_element_op};
     }

     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                      const void* p_b,
-                                                      void* p_c,
-                                                      index_t M,
-                                                      index_t N,
-                                                      index_t K,
-                                                      index_t StrideA,
-                                                      index_t StrideB,
-                                                      index_t StrideE,
-                                                      index_t BatchStrideA,
-                                                      index_t BatchStrideB,
-                                                      index_t BatchStrideE,
-                                                      index_t Batch,
-                                                      AElementwiseOperation a_element_op,
-                                                      BElementwiseOperation b_element_op,
-                                                      CDEElementwiseOperation c_element_op) override
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        std::array<const void*, NumDTensor> p_ds,
+                        void* p_c,
+                        index_t M,
+                        index_t N,
+                        index_t K,
+                        index_t StrideA,
+                        index_t StrideB,
+                        std::array<ck::index_t, NumDTensor> StrideDs,
+                        index_t StrideE,
+                        index_t BatchStrideA,
+                        index_t BatchStrideB,
+                        std::array<ck::index_t, NumDTensor> BatchStrideDs,
+                        index_t BatchStrideE,
+                        index_t Batch,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CDEElementwiseOperation cde_element_op) override
     {
-        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
-                                          static_cast<const BDataType*>(p_b),
-                                          static_cast<EDataType*>(p_c),
+        return std::make_unique<Argument>(p_a,
+                                          p_b,
+                                          p_ds,
+                                          p_c,
                                           M,
                                           N,
                                           K,
                                           StrideA,
                                           StrideB,
+                                          StrideDs,
                                           StrideE,
                                           BatchStrideA,
                                           BatchStrideB,
+                                          BatchStrideDs,
                                           BatchStrideE,
                                           Batch,
                                           1,
                                           1,
                                           a_element_op,
                                           b_element_op,
-                                          c_element_op);
+                                          cde_element_op);
     }

     // polymorphic
...