Unverified Commit 4c850c90 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Merge branch 'develop' into lwpck-1815

parents ce30621d 1973903f
......@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
#include "ck/host_utility/hip_check_error.hpp"
......
......@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
#include "ck/host_utility/hip_check_error.hpp"
......
......@@ -63,7 +63,7 @@ using DeviceGemmInstance =
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>;
// clang-format on
struct ProblemSize final
......
......@@ -144,12 +144,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
......@@ -446,12 +446,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
......
......@@ -153,12 +153,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = PrefetchStages;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
if(num_loop % PrefetchStages == 1)
{
......@@ -646,12 +646,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = PrefetchStages;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
if(num_loop % PrefetchStages == 1)
{
......
......@@ -146,12 +146,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
......
......@@ -147,12 +147,12 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t GlobalBufferNum = 2;
static constexpr index_t HotloopUnroll = 2;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
if(num_loop % HotloopUnroll == 1)
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -501,29 +501,24 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// for sanity check of vector memory access
for(index_t i = 0; i < NumATensor; ++i)
{
as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1;
as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1;
as_max_read_elems_[i] =
tie(as_continous_dim_[i], as_max_read_elems_[i]) =
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
}
for(index_t i = 0; i < NumBTensor; ++i)
{
bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1;
bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1;
bs_max_read_elems_[i] =
tie(bs_continous_dim_[i], bs_max_read_elems_[i]) =
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
}
for(index_t i = 0; i < NumDTensor; ++i)
{
ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
ds_max_read_elems_[i] =
tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
}
e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1;
e_max_write_elems_ = CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
tie(e_continous_dim_, e_max_write_elems_) =
CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
}
// pointers
......@@ -553,14 +548,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
// in the memory or not.
std::array<bool, NumATensor> as_mz_consecutive_;
std::array<bool, NumATensor> as_kz_consecutive_;
std::array<bool, NumBTensor> bs_nz_consecutive_;
std::array<bool, NumBTensor> bs_kz_consecutive_;
std::array<bool, NumDTensor> ds_nz_consecutive_;
bool e_nz_consecutive_;
// Describe whether the last part of a given dimension of A/B/D/E is continues dim.
std::array<index_t, NumATensor> as_continous_dim_;
std::array<index_t, NumATensor> bs_continous_dim_;
std::array<index_t, NumBTensor> ds_continous_dim_;
index_t e_continous_dim_;
std::array<index_t, NumATensor> as_max_read_elems_;
std::array<index_t, NumBTensor> bs_max_read_elems_;
......@@ -659,9 +651,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_a_vector_size =
arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m =
ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i];
ABlockTransferSrcVectorDim == 1 && arg.as_continous_dim_[i] == 0;
const bool valid_a_access_dim_k =
ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
ABlockTransferSrcVectorDim == 2 && arg.as_continous_dim_[i] == 1;
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
if(!((valid_a_vector_size && valid_a_access_dim) ||
ABlockTransferSrcScalarPerVector == 1))
......@@ -679,9 +671,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_b_vector_size =
arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n =
BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i];
BBlockTransferSrcVectorDim == 1 && arg.bs_continous_dim_[i] == 0;
const bool valid_b_access_dim_k =
BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
BBlockTransferSrcVectorDim == 2 && arg.bs_continous_dim_[i] == 1;
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
if(!((valid_b_vector_size && valid_b_access_dim) ||
BBlockTransferSrcScalarPerVector == 1))
......@@ -699,7 +691,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_d_vector_size =
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension.
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
const bool valid_d_access_dim = arg.ds_continous_dim_[i] == 1;
if(!((valid_d_vector_size && valid_d_access_dim) ||
CDEBlockTransferScalarPerVector_NPerBlock == 1))
{
......@@ -714,7 +706,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_e_vector_size =
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
const bool valid_e_access_dim = arg.e_nz_consecutive_;
const bool valid_e_access_dim = arg.e_continous_dim_ == 1;
if(!((valid_e_vector_size && valid_e_access_dim) ||
CDEBlockTransferScalarPerVector_NPerBlock == 1))
{
......
......@@ -442,25 +442,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
}
// for sanity check of vector memory access
a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1;
a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
a_max_read_elems_ =
tie(a_continous_dim_, a_max_read_elems_) =
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1;
b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
b_max_read_elems_ =
tie(b_continous_dim_, b_max_read_elems_) =
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
for(index_t i = 0; i < NumDTensor; ++i)
{
ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
ds_max_read_elems_[i] =
tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
}
e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
e_max_write_elems_ =
tie(e_continous_dim_, e_max_write_elems_) =
CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
}
......@@ -501,14 +495,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
// in the memory or not.
bool a_mz_consecutive_;
bool a_kz_consecutive_;
bool b_nz_consecutive_;
bool b_kz_consecutive_;
std::array<bool, NumDTensor> ds_nz_consecutive_;
bool e_nz_consecutive_;
// Describe whether the last part of a given dimension of A/B/D/E is continues dim.
index_t a_continous_dim_;
index_t b_continous_dim_;
std::array<index_t, NumDTensor> ds_continous_dim_;
index_t e_continous_dim_;
index_t a_max_read_elems_;
index_t b_max_read_elems_;
......@@ -624,8 +615,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_a_vector_size =
arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
const bool valid_a_access_dim_m =
ABlockTransferSrcVectorDim == 1 && arg.a_continous_dim_ == 0;
const bool valid_a_access_dim_k =
ABlockTransferSrcVectorDim == 2 && arg.a_continous_dim_ == 1;
const bool valid_a_access_dim =
valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
if(!(valid_a_vector_size && valid_a_access_dim))
......@@ -635,8 +628,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_b_vector_size =
arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
const bool valid_b_access_dim_n =
BBlockTransferSrcVectorDim == 1 && arg.b_continous_dim_ == 0;
const bool valid_b_access_dim_k =
BBlockTransferSrcVectorDim == 2 && arg.b_continous_dim_ == 1;
const bool valid_b_access_dim =
valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
if(!(valid_b_vector_size && valid_b_access_dim))
......@@ -650,7 +645,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension.
const bool valid_d_access_dim =
arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1;
arg.ds_continous_dim_[i] == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_d_vector_size && valid_d_access_dim))
{
valid_ds_access = false;
......@@ -665,7 +660,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
const bool valid_e_access_dim =
arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1;
arg.e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_e_vector_size && valid_e_access_dim))
{
return false;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -50,25 +50,53 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
}
// Determine the beginning and end idx of the group representing the FCD.
index_t begin_idx, end_idx;
if(strides[NumDim1 - 1] == 1)
index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1;
if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1)
{
// MZ or KZ are ones
bool dims1_are_ones = true;
for(index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++)
{
if(lengths[dim_idx] != 1)
{
dims1_are_ones = false;
}
}
if(dims1_are_ones)
{
begin_idx = NumDim1;
end_idx = NumDim1 + NumDim2 - 1;
continous_dim = 1;
}
else
{
begin_idx = 0;
end_idx = NumDim1 - 1;
continous_dim = 0;
}
}
else if(strides[NumDim1 - 1] == 1)
{
begin_idx = 0;
end_idx = NumDim1 - 1;
continous_dim = 0;
}
else if(strides[NumDim1 + NumDim2 - 1] == 1)
{
begin_idx = NumDim1;
end_idx = NumDim1 + NumDim2 - 1;
continous_dim = 1;
}
else
{
// The dimension consecutive in memory is not the last dimension of any group, so only
// one element can be read/written at once.
return 1;
consecutive_stride = 1;
continous_dim = 0;
return make_tuple(continous_dim, consecutive_stride);
}
index_t consecutive_stride = 1;
for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
{
if(strides[dim_idx] == consecutive_stride)
......@@ -81,7 +109,7 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
}
}
const index_t max_subsequent_elems = consecutive_stride;
return max_subsequent_elems;
return make_tuple(continous_dim, max_subsequent_elems);
}
} // namespace device
......
......@@ -93,9 +93,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......
......@@ -54,9 +54,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t c_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
......
......@@ -66,9 +66,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t c_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
......
......@@ -59,9 +59,12 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -113,9 +116,12 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
......
......@@ -97,9 +97,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t c_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......
......@@ -106,10 +106,12 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -170,10 +172,13 @@ __global__ void
}
else
{
const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid + a_batch_offset + a_n_offset,
......
......@@ -85,12 +85,17 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -142,12 +147,17 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
......
......@@ -161,11 +161,11 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
const long_index_t a_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
const long_index_t b_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
const long_index_t e_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......
......@@ -19,6 +19,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
......@@ -42,16 +43,22 @@ namespace device {
template <typename GridwiseGemm,
typename GemmDesc,
GemmSpecialization GemmSpec,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
index_t KPerBlock,
typename OffsettedBlockToCTileMap,
typename LocalBlock2ETileMap,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
typename CDEElementwiseOperation,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -67,6 +74,7 @@ __global__ void
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
__shared__ uint8_t p_shared1[shared_size];
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
......@@ -81,27 +89,8 @@ __global__ void
index_t gemm_tile_id_start = 0;
index_t gemm_tile_id_end = 0;
using AGridDescMK =
remove_cvref_t<decltype(GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(
1, 1, 1))>;
using BGridDescNK =
remove_cvref_t<decltype(GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(
1, 1, 1))>;
using EGridDescMN =
remove_cvref_t<decltype(GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
1, 1, 1))>;
using DsGridDescMN =
remove_cvref_t<decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
{}, {}, {}))>;
index_t M = 0, N = 0, K = 0;
index_t StrideA, StrideB, StrideE;
std::array<index_t, NumDTensor> StrideDs;
AGridDescMK a_grid_desc_mk;
BGridDescNK b_grid_desc_nk;
EGridDescMN e_grid_desc_mn;
DsGridDescMN ds_grid_desc_mn;
auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);
do
......@@ -127,31 +116,13 @@ __global__ void
}
b2c_tile_map =
OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N), group_offset, tile_offset);
OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset);
grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
gemm_tile_id_start = group_offset;
gemm_tile_id_end = group_offset + grid_size_grp;
}
StrideA = gemm_desc_ptr[group_id].StrideA;
StrideB = gemm_desc_ptr[group_id].StrideB;
StrideDs = gemm_desc_ptr[group_id].StrideDs;
StrideE = gemm_desc_ptr[group_id].StrideE;
a_grid_desc_mk =
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
b_grid_desc_nk =
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
e_grid_desc_mn =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_mn(j) = GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
M, N, StrideDs[j]);
});
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
DsGridPointer p_ds_grid;
......@@ -160,43 +131,269 @@ __global__ void
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
bool has_main_kblock_loop =
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_mk.GetLength(Number<1>{}));
static constexpr index_t kbatch = 1;
static constexpr index_t k_grain = kbatch * KPerBlock;
index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
// Update tile offset if we have moved within group
b2c_tile_map.UpdateTileOffset(tile_offset);
if(has_main_kblock_loop)
using Problem = typename GridwiseGemm::Problem;
auto problem = Problem(gemm_desc_ptr[group_id].M,
gemm_desc_ptr[group_id].N,
gemm_desc_ptr[group_id].K,
gemm_desc_ptr[group_id].StrideA,
gemm_desc_ptr[group_id].StrideB,
gemm_desc_ptr[group_id].StrideDs,
gemm_desc_ptr[group_id].StrideE,
kbatch);
if(has_main_k_block_loop)
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Full>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::One>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Full>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Two>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Three>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Four>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Five>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Six>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Seven>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Odd>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
static_cast<void*>(p_shared1),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
else
{
GridwiseGemm::template Run<true>(gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Even>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
gemm_desc_ptr[group_id].p_e_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
static_cast<void*>(p_shared1),
problem,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_mk,
b_grid_desc_nk,
ds_grid_desc_mn,
e_grid_desc_mn,
b2c_tile_map);
}
}
}
else
{
GridwiseGemm::template Run<false>(gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
false,
InMemoryDataOperationEnum::Set,
TailNumber::Full>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
gemm_desc_ptr[group_id].p_e_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_mk,
b_grid_desc_nk,
ds_grid_desc_mn,
e_grid_desc_mn,
b2c_tile_map);
}
}
tile_id += get_grid_size();
tile_offset += get_grid_size();
......@@ -253,10 +450,12 @@ template <typename ALayout,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename ComputeDataType = EDataType>
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
: public DeviceGroupedGemmTileLoop<ALayout,
BLayout,
......@@ -273,10 +472,13 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop;
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
......@@ -284,8 +486,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
......@@ -315,58 +516,15 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap
{
using underlying_type = UnderlyingBlockToCTileMap;
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
index_t group_offset,
index_t tile_offset)
: block_to_ctile_map_{block_to_ctile_map},
group_offset_{group_offset},
tile_offset_{tile_offset}
{
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
{
return block_to_ctile_map_.CalculateGridSize(M, N);
}
__device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t group_offset_;
index_t tile_offset_;
};
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
using Block2ETileMap = BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock>;
using OffsetedLocalBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
// Argument
struct Argument : public BaseArgument
......@@ -403,7 +561,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const void* p_dev_gemm_args_;
int occupancy_num_blocks_;
int gpu_cu_count_;
const std::vector<GemmDesc>& gemm_descs_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
......@@ -496,16 +653,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
KernelArguments,
GemmSpec,
ADataType,
BDataType,
DsDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
OffsetedLocalBlock2ETileMap,
KPerBlock,
OffsettedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
CDEElementwiseOperation,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
}
......@@ -546,6 +709,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
<< std::endl;
}
// run multiple kernels
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
......@@ -572,63 +737,41 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
return false;
}
using DsGridDescMN = remove_cvref_t<
decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
{}, {}, {}))>;
bool supported = true;
for(const auto& gdesc : arg.gemm_descs_)
{
const auto M = gdesc.M_;
const auto N = gdesc.N_;
const auto K = gdesc.K_;
const auto StrideA = gdesc.stride_A_;
const auto StrideB = gdesc.stride_B_;
const auto StrideE = gdesc.stride_C_;
const auto& StrideDs = gdesc.stride_Ds_;
// If M dimension is unknown at launch time then validate just NK.
// If N or K dim is zero (or unknown) then the vector loads responsibility lies on
// the user.
if(N * K == 0)
continue;
const auto a_grid_desc_mk =
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
const auto b_grid_desc_nk =
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
const auto e_grid_desc_mn =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
DsGridDescMN ds_grid_desc_mn;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_mn(j) =
GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
M, N, StrideDs[j]);
});
const auto b2c_tile_map = Block2ETileMap(M, N);
constexpr index_t k_batch = 1;
for(index_t i = 0; i < arg.group_count_; ++i)
{
std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
std::array<index_t, NumDTensor> stride_Ds;
std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
using GridArg = typename GridwiseGemm::Argument;
GridArg gridwise_arg(nullptr, // p_a_grid,
nullptr, // p_b_grid,
placeholder_p_ds_grid, // p_ds_grid,
nullptr, // p_e_grid ,
arg.gemm_descs_[i].M_,
arg.gemm_descs_[i].N_,
arg.gemm_descs_[i].K_,
arg.gemm_descs_[i].stride_A_,
arg.gemm_descs_[i].stride_B_,
stride_Ds,
arg.gemm_descs_[i].stride_C_,
k_batch,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_);
if(!(GridwiseGemm::template CheckValidity(a_grid_desc_mk,
b_grid_desc_nk,
ds_grid_desc_mn,
e_grid_desc_mn,
b2c_tile_map) &&
GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
M, N, K)))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) &&
!(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
<< K << "] are not supported by current template parameters!"
<< " In " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
}
supported = false;
return false;
}
supported = supported && GridwiseGemm::CheckValidity(gridwise_arg);
}
return supported;
......@@ -651,16 +794,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
KernelArguments,
GemmSpec,
ADataType,
BDataType,
DsDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
OffsetedLocalBlock2ETileMap,
KPerBlock,
OffsettedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
CDEElementwiseOperation,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
int occupancy, num_cu;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
......@@ -696,16 +845,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
KernelArguments,
GemmSpec,
ADataType,
BDataType,
DsDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
OffsetedLocalBlock2ETileMap,
KPerBlock,
OffsettedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
CDEElementwiseOperation,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
int occupancy, num_cu;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
......@@ -739,6 +894,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
{
auto str = std::ostringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop"
<< "<"
......@@ -760,8 +926,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< PipelineVer << ", "
<< LoopSched
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
<< ">";
// clang-format on
......
......@@ -908,6 +908,51 @@ struct OffsettedBlockToCTileMap
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
};
// second version with 2 offsets
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap2
{
using underlying_type = UnderlyingBlockToCTileMap;
__host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map,
index_t group_offset,
index_t tile_offset)
: block_to_ctile_map_{block_to_ctile_map},
group_offset_{group_offset},
tile_offset_{tile_offset}
{
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
{
return block_to_ctile_map_.CalculateGridSize(M, N);
}
__device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t group_offset_;
index_t tile_offset_;
};
/**
* @brief Simple tile mapping which creates 3D grid of block of threads.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment