Unverified commit 4c850c90, authored by Rostyslav Geyyer, committed by GitHub

Merge branch 'develop' into lwpck-1815

parents ce30621d 1973903f
@@ -13,7 +13,7 @@
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
 #include "ck/host_utility/hip_check_error.hpp"
...
@@ -13,7 +13,7 @@
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
 #include "ck/host_utility/hip_check_error.hpp"
...
@@ -63,7 +63,7 @@ using DeviceGemmInstance =
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-    < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
+    < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>;
// clang-format on
struct ProblemSize final
...
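The only change to this instance is its final template argument: a single scalar-per-vector value becomes a sequence. Presumably the entries cover the E output plus the D tensors, but the diff does not spell that out; the snippet below is an interpretation, not a statement of the library API.

#include "ck/utility/sequence.hpp"

// In the CK examples, S is the usual shorthand for ck::Sequence:
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Old instance ended with a single scalar-per-vector value:  ..., S<1, 32, 1, 8>, 4>;
// New instance passes a sequence instead:                    ..., S<1, 32, 1, 8>, S<4, 4, 4>>;
// (Assumed meaning: one entry for E and one per D tensor; the diff itself does not say.)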
@@ -144,12 +144,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
-__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
    return num_loop > PrefetchStages;
}
-__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
    ignore = num_loop;
    return TailNumber::Full;
@@ -446,12 +446,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
-__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
    return num_loop > PrefetchStages;
}
-__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
    ignore = num_loop;
    return TailNumber::Full;
...
@@ -153,12 +153,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = PrefetchStages;
-__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
    return num_loop > PrefetchStages;
}
-__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
    if(num_loop % PrefetchStages == 1)
    {
@@ -646,12 +646,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = PrefetchStages;
-__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
    return num_loop > PrefetchStages;
}
-__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
    if(num_loop % PrefetchStages == 1)
    {
...
@@ -146,12 +146,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
-__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
    return num_loop > PrefetchStages;
}
-__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
    ignore = num_loop;
    return TailNumber::Full;
...
@@ -147,12 +147,12 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t GlobalBufferNum = 2;
static constexpr index_t HotloopUnroll = 2;
-__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
    return num_loop > PrefetchStages;
}
-__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
    if(num_loop % HotloopUnroll == 1)
    {
...
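These pipeline hunks only add __device__ to the hot-loop/tail queries, presumably so the same traits can be evaluated from device code as well as from the host (the tile-loop kernel later in this diff calls the corresponding GridwiseGemm helpers on the GPU). A minimal sketch of a helper usable in both contexts; this is illustrative code, not part of the library:

#include <hip/hip_runtime.h>

// Illustrative only: a constexpr helper that both the host-side dispatcher and a
// device-side kernel can evaluate. Marking it __host__ __device__ is exactly what
// the change above enables for the pipeline traits.
__host__ __device__ constexpr bool HasHotLoop(int num_loop, int prefetch_stages)
{
    return num_loop > prefetch_stages; // same rule as BlockHasHotloop above
}

__global__ void demo_kernel(int num_loop, bool* out)
{
    // evaluated on the device, e.g. to pick a tail-handling branch inside the kernel
    *out = HasHotLoop(num_loop, /*prefetch_stages=*/2);
}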
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -501,29 +501,24 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// for sanity check of vector memory access
for(index_t i = 0; i < NumATensor; ++i)
{
-    as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1;
-    as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1;
-    as_max_read_elems_[i] =
-        CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
+    tie(as_continous_dim_[i], as_max_read_elems_[i]) =
+        CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
}
for(index_t i = 0; i < NumBTensor; ++i)
{
-    bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1;
-    bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1;
-    bs_max_read_elems_[i] =
-        CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
+    tie(bs_continous_dim_[i], bs_max_read_elems_[i]) =
+        CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
}
for(index_t i = 0; i < NumDTensor; ++i)
{
-    ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
-    ds_max_read_elems_[i] =
-        CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
+    tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
+        CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
}
-e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1;
-e_max_write_elems_ = CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
+tie(e_continous_dim_, e_max_write_elems_) =
+    CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
}
// pointers
@@ -553,14 +548,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
-// Describe whether the last part of a given dimension of A/B/D/E is consecutive
-// in the memory or not.
-std::array<bool, NumATensor> as_mz_consecutive_;
-std::array<bool, NumATensor> as_kz_consecutive_;
-std::array<bool, NumBTensor> bs_nz_consecutive_;
-std::array<bool, NumBTensor> bs_kz_consecutive_;
-std::array<bool, NumDTensor> ds_nz_consecutive_;
-bool e_nz_consecutive_;
+// Index of the dimension group of A/B/D/E whose last part is contiguous in memory.
+std::array<index_t, NumATensor> as_continous_dim_;
+std::array<index_t, NumATensor> bs_continous_dim_;
+std::array<index_t, NumBTensor> ds_continous_dim_;
+index_t e_continous_dim_;
std::array<index_t, NumATensor> as_max_read_elems_;
std::array<index_t, NumBTensor> bs_max_read_elems_;
@@ -659,9 +651,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_a_vector_size =
    arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m =
-    ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i];
+    ABlockTransferSrcVectorDim == 1 && arg.as_continous_dim_[i] == 0;
const bool valid_a_access_dim_k =
-    ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
+    ABlockTransferSrcVectorDim == 2 && arg.as_continous_dim_[i] == 1;
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
if(!((valid_a_vector_size && valid_a_access_dim) ||
    ABlockTransferSrcScalarPerVector == 1))
@@ -679,9 +671,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_b_vector_size =
    arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n =
-    BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i];
+    BBlockTransferSrcVectorDim == 1 && arg.bs_continous_dim_[i] == 0;
const bool valid_b_access_dim_k =
-    BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
+    BBlockTransferSrcVectorDim == 2 && arg.bs_continous_dim_[i] == 1;
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
if(!((valid_b_vector_size && valid_b_access_dim) ||
    BBlockTransferSrcScalarPerVector == 1))
@@ -699,7 +691,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_d_vector_size =
    arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension.
-const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
+const bool valid_d_access_dim = arg.ds_continous_dim_[i] == 1;
if(!((valid_d_vector_size && valid_d_access_dim) ||
    CDEBlockTransferScalarPerVector_NPerBlock == 1))
{
@@ -714,7 +706,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_e_vector_size =
    arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
-const bool valid_e_access_dim = arg.e_nz_consecutive_;
+const bool valid_e_access_dim = arg.e_continous_dim_ == 1;
if(!((valid_e_vector_size && valid_e_access_dim) ||
    CDEBlockTransferScalarPerVector_NPerBlock == 1))
{
...
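The A/B/D/E checks above all share one shape: vector access along a dimension is only allowed if that dimension group is the contiguous one and the contiguous element count divides the vector width, with scalar-per-vector equal to 1 always accepted. A standalone sketch of that rule (illustrative names, not CK code):

#include <cstdint>

// Illustrative only: mirrors the reworked sanity check above.
// continuous_dim: 0 -> first dimension group (M or N) is contiguous, 1 -> second group (K or N).
bool IsVectorAccessValid(int continuous_dim,
                         int64_t max_contiguous_elems,
                         int src_vector_dim,   // 1 = first group, 2 = second group
                         int scalar_per_vector)
{
    const bool size_ok = max_contiguous_elems % scalar_per_vector == 0;
    const bool dim_ok  = (src_vector_dim == 1 && continuous_dim == 0) ||
                         (src_vector_dim == 2 && continuous_dim == 1);
    // scalar_per_vector == 1 always works, regardless of layout
    return (size_ok && dim_ok) || scalar_per_vector == 1;
}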
@@ -442,25 +442,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
}
// for sanity check of vector memory access
-a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1;
-a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
-a_max_read_elems_ =
-    CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
+tie(a_continous_dim_, a_max_read_elems_) =
+    CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
-b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1;
-b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
-b_max_read_elems_ =
-    CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
+tie(b_continous_dim_, b_max_read_elems_) =
+    CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
for(index_t i = 0; i < NumDTensor; ++i)
{
-    ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
-    ds_max_read_elems_[i] =
-        CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
+    tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
+        CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
}
-e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
-e_max_write_elems_ =
-    CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
+tie(e_continous_dim_, e_max_write_elems_) =
+    CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
}
@@ -501,14 +495,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
-// Describe whether the last part of a given dimension of A/B/D/E is consecutive
-// in the memory or not.
-bool a_mz_consecutive_;
-bool a_kz_consecutive_;
-bool b_nz_consecutive_;
-bool b_kz_consecutive_;
-std::array<bool, NumDTensor> ds_nz_consecutive_;
-bool e_nz_consecutive_;
+// Index of the dimension group of A/B/D/E whose last part is contiguous in memory.
+index_t a_continous_dim_;
+index_t b_continous_dim_;
+std::array<index_t, NumDTensor> ds_continous_dim_;
+index_t e_continous_dim_;
index_t a_max_read_elems_;
index_t b_max_read_elems_;
@@ -624,8 +615,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_a_vector_size =
    arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
-const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
-const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
+const bool valid_a_access_dim_m =
+    ABlockTransferSrcVectorDim == 1 && arg.a_continous_dim_ == 0;
+const bool valid_a_access_dim_k =
+    ABlockTransferSrcVectorDim == 2 && arg.a_continous_dim_ == 1;
const bool valid_a_access_dim =
    valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
if(!(valid_a_vector_size && valid_a_access_dim))
@@ -635,8 +628,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_b_vector_size =
    arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
-const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
-const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
+const bool valid_b_access_dim_n =
+    BBlockTransferSrcVectorDim == 1 && arg.b_continous_dim_ == 0;
+const bool valid_b_access_dim_k =
+    BBlockTransferSrcVectorDim == 2 && arg.b_continous_dim_ == 1;
const bool valid_b_access_dim =
    valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
if(!(valid_b_vector_size && valid_b_access_dim))
@@ -650,7 +645,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
    arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension.
const bool valid_d_access_dim =
-    arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1;
+    arg.ds_continous_dim_[i] == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_d_vector_size && valid_d_access_dim))
{
    valid_ds_access = false;
@@ -665,7 +660,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
    arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
const bool valid_e_access_dim =
-    arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1;
+    arg.e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_e_vector_size && valid_e_access_dim))
{
    return false;
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -50,25 +50,53 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
}
// Determine the beginning and end idx of the group representing the FCD.
-index_t begin_idx, end_idx;
-if(strides[NumDim1 - 1] == 1)
+index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1;
+if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1)
{
-    begin_idx = 0;
-    end_idx   = NumDim1 - 1;
+    // Both groups end with stride 1; if every dim of the first group has length 1,
+    // treat the second group as the contiguous one.
+    bool dims1_are_ones = true;
+    for(index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++)
+    {
+        if(lengths[dim_idx] != 1)
+        {
+            dims1_are_ones = false;
+        }
+    }
+    if(dims1_are_ones)
+    {
+        begin_idx     = NumDim1;
+        end_idx       = NumDim1 + NumDim2 - 1;
+        continous_dim = 1;
+    }
+    else
+    {
+        begin_idx     = 0;
+        end_idx       = NumDim1 - 1;
+        continous_dim = 0;
+    }
+}
+else if(strides[NumDim1 - 1] == 1)
+{
+    begin_idx     = 0;
+    end_idx       = NumDim1 - 1;
+    continous_dim = 0;
}
else if(strides[NumDim1 + NumDim2 - 1] == 1)
{
    begin_idx = NumDim1;
    end_idx   = NumDim1 + NumDim2 - 1;
+    continous_dim = 1;
}
else
{
    // The dimension consecutive in memory is not the last dimension of any group, so only
    // one element can be read/written at once.
-    return 1;
+    consecutive_stride = 1;
+    continous_dim      = 0;
+    return make_tuple(continous_dim, consecutive_stride);
}
-index_t consecutive_stride = 1;
for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
{
    if(strides[dim_idx] == consecutive_stride)
@@ -81,7 +109,7 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
    }
}
const index_t max_subsequent_elems = consecutive_stride;
-return max_subsequent_elems;
+return make_tuple(continous_dim, max_subsequent_elems);
}
} // namespace device
...
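CalculateMaxRead now returns a pair: which dimension group ends in the memory-contiguous dimension, and how many elements can be read contiguously there. A standalone illustration of the same idea with a small worked example (this is not the CK implementation; names and types are simplified):

#include <cstdio>
#include <tuple>
#include <vector>

// Standalone illustration only: same idea as CalculateMaxRead above.
// Returns {contiguous_group, max_contiguous_elems} for lengths/strides split into a
// first group of NumDim1 dims and a second group of NumDim2 dims.
std::tuple<int, int> MaxContiguousRead(int NumDim1, int NumDim2,
                                       const std::vector<int>& lengths,
                                       const std::vector<int>& strides)
{
    int begin_idx, end_idx, group;
    if(strides[NumDim1 - 1] == 1)
    {
        begin_idx = 0; end_idx = NumDim1 - 1; group = 0;
    }
    else if(strides[NumDim1 + NumDim2 - 1] == 1)
    {
        begin_idx = NumDim1; end_idx = NumDim1 + NumDim2 - 1; group = 1;
    }
    else
    {
        return {0, 1}; // no group ends with stride 1: only one element at a time
    }

    int contiguous = 1;
    for(int d = end_idx; d >= begin_idx; --d)
    {
        if(strides[d] != contiguous) break; // contiguity chain broken
        contiguous *= lengths[d];
    }
    return {group, contiguous};
}

int main()
{
    // lengths (M0, M1, K0, K1) = (8, 4, 16, 2); strides for a layout where K is fastest:
    // stride(M0)=128, stride(M1)=32, stride(K0)=2, stride(K1)=1
    auto [group, elems] = MaxContiguousRead(2, 2, {8, 4, 16, 2}, {128, 32, 2, 1});
    std::printf("group=%d elems=%d\n", group, elems); // prints group=1 elems=32
}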
@@ -93,9 +93,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
-const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
-const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
+const long_index_t a_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+const long_index_t b_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+const long_index_t e_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
...
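This and the following kernel hunks wrap the batch offsets in CK's amd_wave_read_first_lane helper instead of calling the raw builtin (or not uniformizing them at all). The offsets are long_index_t, while __builtin_amdgcn_readfirstlane has traditionally been limited to 32-bit values, so the wrapper appears intended to broadcast lane 0's value for wider types and keep the offsets in scalar registers. A rough sketch of that idea; this is illustrative and NOT the library implementation:

#include <hip/hip_runtime.h>

// Illustrative device-side sketch (not CK code): broadcast lane 0's 64-bit value by
// applying the 32-bit readfirstlane builtin to the low and high dwords separately.
__device__ inline unsigned long long wave_read_first_lane_u64(unsigned long long v)
{
    const unsigned lo =
        static_cast<unsigned>(__builtin_amdgcn_readfirstlane(static_cast<int>(v & 0xffffffffu)));
    const unsigned hi =
        static_cast<unsigned>(__builtin_amdgcn_readfirstlane(static_cast<int>(v >> 32)));
    return (static_cast<unsigned long long>(hi) << 32) | lo;
}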
@@ -54,9 +54,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
-const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
-const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
+const long_index_t a_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+const long_index_t b_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+const long_index_t c_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
...
@@ -66,9 +66,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
-const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
-const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
+const long_index_t a_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+const long_index_t b_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+const long_index_t c_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
...
@@ -59,9 +59,12 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
-const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
-const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
-const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
+const long_index_t a_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+const long_index_t b_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+const long_index_t e_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -113,9 +116,12 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
-const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
-const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
-const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
+const long_index_t a_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+const long_index_t b_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+const long_index_t e_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
// Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
// can operate on different LDS chunks at the same time without an ordering dependency.
...
@@ -97,9 +97,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
-const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
-const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
+const long_index_t a_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+const long_index_t b_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+const long_index_t c_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
...
@@ -106,10 +106,12 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
-const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
+const long_index_t e_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
-const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
+const long_index_t e_n_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -170,10 +172,13 @@ __global__ void
}
else
{
-    const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
-    const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
-    const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
+    const long_index_t a_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
+    const long_index_t b_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
+    const long_index_t a_n_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
    GridwiseGemm::template Run<HasMainKBlockLoop>(
        p_as_grid + a_batch_offset + a_n_offset,
...
@@ -85,12 +85,17 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
-const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
-const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
-const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
-const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
-const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
+const long_index_t a_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
+const long_index_t b_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
+const long_index_t e_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
+const long_index_t a_n_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
+const long_index_t e_n_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -142,12 +147,17 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
-const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
-const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
-const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
-const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
-const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
+const long_index_t a_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
+const long_index_t b_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
+const long_index_t e_batch_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
+const long_index_t a_n_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
+const long_index_t e_n_offset =
+    amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
// Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
// can operate on different LDS chunks at the same time without an ordering dependency.
...
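The "two LDS pointers" comment above refers to double-buffered shared memory: while one chunk is being read for the current iteration, the next tile is staged into the other chunk, so the compiler does not have to serialize ds_read and ds_write. A minimal ping-pong sketch of that idea in HIP; this is illustrative only and not the CK kernel:

#include <hip/hip_runtime.h>

// Illustrative LDS ping-pong (not CK code): assumes blockDim.x == 256 and
// `in` holding num_tiles * 256 floats.
__global__ void pingpong_demo(const float* in, float* out, int num_tiles)
{
    __shared__ float lds0[256];
    __shared__ float lds1[256];
    float* lds[2] = {lds0, lds1};

    const int t = threadIdx.x;
    lds[0][t] = in[t]; // prefetch the first tile
    __syncthreads();

    float acc = 0.f;
    for(int tile = 0; tile < num_tiles; ++tile)
    {
        float* cur  = lds[tile & 1];
        float* next = lds[(tile + 1) & 1];
        if(tile + 1 < num_tiles)
            next[t] = in[(tile + 1) * 256 + t]; // stage the next tile into the other buffer
        acc += cur[t];                          // consume the current buffer
        __syncthreads();                        // make the staged buffer visible next round
    }
    out[t] = acc;
}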
@@ -161,11 +161,11 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
+const long_index_t a_batch_offset = amd_wave_read_first_lane(
    static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
+const long_index_t b_batch_offset = amd_wave_read_first_lane(
    static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
+const long_index_t e_batch_offset = amd_wave_read_first_lane(
    static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
...
@@ -19,6 +19,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // remove the old ones
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
@@ -42,16 +43,22 @@ namespace device {
template <typename GridwiseGemm,
          typename GemmDesc,
          GemmSpecialization GemmSpec,
+         typename ADataType,
+         typename BDataType,
          typename DsDataType,
+         typename EDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
+         index_t KPerBlock,
          typename OffsettedBlockToCTileMap,
          typename LocalBlock2ETileMap,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
-         typename CDEElementwiseOperation>
+         typename CDEElementwiseOperation,
+         BlockGemmPipelineScheduler BlkGemmPipeSched,
+         BlockGemmPipelineVersion BlkGemmPipelineVer>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -67,6 +74,7 @@ __global__ void
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
+__shared__ uint8_t p_shared1[shared_size];
const auto gemm_desc_ptr =
    reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
@@ -81,27 +89,8 @@ __global__ void
index_t gemm_tile_id_start = 0;
index_t gemm_tile_id_end = 0;
-using AGridDescMK =
-    remove_cvref_t<decltype(GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(
-        1, 1, 1))>;
-using BGridDescNK =
-    remove_cvref_t<decltype(GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(
-        1, 1, 1))>;
-using EGridDescMN =
-    remove_cvref_t<decltype(GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
-        1, 1, 1))>;
-using DsGridDescMN =
-    remove_cvref_t<decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
-        {}, {}, {}))>;
index_t M = 0, N = 0, K = 0;
-index_t StrideA, StrideB, StrideE;
-std::array<index_t, NumDTensor> StrideDs;
-AGridDescMK a_grid_desc_mk;
-BGridDescNK b_grid_desc_nk;
-EGridDescMN e_grid_desc_mn;
-DsGridDescMN ds_grid_desc_mn;
auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);
do
@@ -127,31 +116,13 @@ __global__ void
}
b2c_tile_map =
-    OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N), group_offset, tile_offset);
+    OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset);
grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
gemm_tile_id_start = group_offset;
gemm_tile_id_end = group_offset + grid_size_grp;
}
-StrideA = gemm_desc_ptr[group_id].StrideA;
-StrideB = gemm_desc_ptr[group_id].StrideB;
-StrideDs = gemm_desc_ptr[group_id].StrideDs;
-StrideE = gemm_desc_ptr[group_id].StrideE;
-a_grid_desc_mk =
-    GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
-b_grid_desc_nk =
-    GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
-e_grid_desc_mn =
-    GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
-static_for<0, NumDTensor, 1>{}([&](auto j) {
-    using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
-    ds_grid_desc_mn(j) = GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
-        M, N, StrideDs[j]);
-});
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
DsGridPointer p_ds_grid;
@@ -160,42 +131,268 @@ __global__ void
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
-bool has_main_kblock_loop =
-    GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_mk.GetLength(Number<1>{}));
+static constexpr index_t kbatch = 1;
+static constexpr index_t k_grain = kbatch * KPerBlock;
+index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock;
+const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
// Update tile offset if we have moved within group
b2c_tile_map.UpdateTileOffset(tile_offset);
-if(has_main_kblock_loop)
+using Problem = typename GridwiseGemm::Problem;
+auto problem = Problem(gemm_desc_ptr[group_id].M,
+                       gemm_desc_ptr[group_id].N,
+                       gemm_desc_ptr[group_id].K,
+                       gemm_desc_ptr[group_id].StrideA,
+                       gemm_desc_ptr[group_id].StrideB,
+                       gemm_desc_ptr[group_id].StrideDs,
+                       gemm_desc_ptr[group_id].StrideE,
+                       kbatch);
+if(has_main_k_block_loop)
{
-GridwiseGemm::template Run<true>(gemm_desc_ptr[group_id].p_a_grid,
-                                 gemm_desc_ptr[group_id].p_b_grid,
-                                 p_ds_grid,
-                                 gemm_desc_ptr[group_id].p_e_grid,
-                                 static_cast<void*>(p_shared),
-                                 a_element_op,
-                                 b_element_op,
-                                 cde_element_op,
-                                 a_grid_desc_mk,
-                                 b_grid_desc_nk,
-                                 ds_grid_desc_mn,
-                                 e_grid_desc_mn,
-                                 b2c_tile_map);
+if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
+             BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
+{
+GridwiseGemm::template Run<OffsettedBlockToCTileMap,
+                           true,
+                           InMemoryDataOperationEnum::Set,
+                           TailNumber::Full>(
+static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+p_ds_grid,
+static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+static_cast<void*>(p_shared),
+problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::One>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Full>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Two>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Three>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Four>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Five>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Six>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Seven>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Odd>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
static_cast<void*>(p_shared1),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
else
{
GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Even>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
static_cast<void*>(p_shared1),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
} }
else
{
-GridwiseGemm::template Run<false>(gemm_desc_ptr[group_id].p_a_grid,
-                                  gemm_desc_ptr[group_id].p_b_grid,
-                                  p_ds_grid,
-                                  gemm_desc_ptr[group_id].p_e_grid,
-                                  static_cast<void*>(p_shared),
-                                  a_element_op,
-                                  b_element_op,
-                                  cde_element_op,
-                                  a_grid_desc_mk,
-                                  b_grid_desc_nk,
-                                  ds_grid_desc_mn,
-                                  e_grid_desc_mn,
-                                  b2c_tile_map);
+if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
+{
+GridwiseGemm::template Run<OffsettedBlockToCTileMap,
+                           false,
+                           InMemoryDataOperationEnum::Set,
+                           TailNumber::Full>(
+static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+p_ds_grid,
+static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+static_cast<void*>(p_shared),
+problem,
+a_element_op,
+b_element_op,
+cde_element_op,
+b2c_tile_map);
+}
}
tile_id += get_grid_size();
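The long if/else chain in the kernel body above instantiates one Run (or Run_2Lds) specialization per possible TailNumber, with PrefetchStages guards so only reachable specializations are compiled. A compact, self-contained sketch of the same compile-time dispatch pattern (simplified names, illustrative only, not the library code):

#include <cstdio>

// Illustrative only: the tail dispatch pattern used by the kernel above, stripped of
// GEMM arguments.
enum class TailNumber { One, Two, Three, Full };

template <TailNumber Tail>
void RunWithTail() { std::printf("instantiated tail %d\n", static_cast<int>(Tail)); }

template <int PrefetchStages>
void DispatchTail(TailNumber tail)
{
    if(tail == TailNumber::One)  RunWithTail<TailNumber::One>();
    if(tail == TailNumber::Full) RunWithTail<TailNumber::Full>();
    // Larger tails are only instantiated when the pipeline can produce them,
    // mirroring the "PrefetchStages > N" guards in the kernel.
    if constexpr(PrefetchStages > 2)
    {
        if(tail == TailNumber::Two) RunWithTail<TailNumber::Two>();
    }
    if constexpr(PrefetchStages > 3)
    {
        if(tail == TailNumber::Three) RunWithTail<TailNumber::Three>();
    }
}

int main() { DispatchTail<3>(TailNumber::Two); }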
@@ -253,10 +450,12 @@ template <typename ALayout,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-PipelineVersion PipelineVer = PipelineVersion::v1,
-LoopScheduler LoopSched = make_default_loop_scheduler(),
-typename ComputeDataType = EDataType>
+typename CDEShuffleBlockTransferScalarPerVectors,
+BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
+BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
+typename ComputeTypeA = EDataType,
+typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
: public DeviceGroupedGemmTileLoop<ALayout,
                                   BLayout,
@@ -273,10 +472,13 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop;
static constexpr index_t NumDTensor = DsDataType::Size();
-using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
+using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<
+    ALayout,
+    BLayout,
+    DsLayout,
+    ELayout,
    ADataType,
    BDataType,
-    ComputeDataType,
    AccDataType,
    CShuffleDataType,
    DsDataType,
@@ -284,8 +486,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
    AElementwiseOperation,
    BElementwiseOperation,
    CDEElementwiseOperation,
-    InMemoryDataOperationEnum::Set,
-    NumGemmKPrefetchStage,
+    GemmSpec,
    BlockSize,
    MPerBlock,
    NPerBlock,
@@ -315,58 +516,15 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
    CShuffleMXdlPerWavePerShuffle,
    CShuffleNXdlPerWavePerShuffle,
    CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-    CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-    LoopSched,
-    PipelineVer>;
+    CDEShuffleBlockTransferScalarPerVectors,
+    BlkGemmPipeSched,
+    BlkGemmPipelineVer,
+    ComputeTypeA,
+    ComputeTypeB>;
-template <typename UnderlyingBlockToCTileMap>
-struct OffsettedBlockToCTileMap
-{
-    using underlying_type = UnderlyingBlockToCTileMap;
-    __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
-                                                 index_t group_offset,
-                                                 index_t tile_offset)
-        : block_to_ctile_map_{block_to_ctile_map},
-          group_offset_{group_offset},
-          tile_offset_{tile_offset}
-    {
-    }
-    template <typename TopIdx>
-    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-    {
-        return block_to_ctile_map_.CalculateBottomIndex(
-            make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
-    }
-    template <typename CTileIdx, typename CTileDim>
-    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                             const CTileDim& c_tile_dim) const
-    {
-        return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
-    }
-    template <typename CGridDesc_M_N>
-    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
-    {
-        return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
-    }
-    __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
-    {
-        return block_to_ctile_map_.CalculateGridSize(M, N);
-    }
-    __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
-    UnderlyingBlockToCTileMap block_to_ctile_map_;
-    index_t group_offset_;
-    index_t tile_offset_;
-};
-using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
-using Block2ETileMap = BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock>;
-using OffsetedLocalBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
+using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
+using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
+using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
@@ -403,7 +561,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
        const void* p_dev_gemm_args_;
        int occupancy_num_blocks_;
        int gpu_cu_count_;
-
        const std::vector<GemmDesc>& gemm_descs_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
@@ -496,16 +653,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
            const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
                    KernelArguments,
                    GemmSpec,
+                   ADataType,
+                   BDataType,
                    DsDataType,
+                   EDataType,
                    ALayout,
                    BLayout,
                    DsLayout,
                    ELayout,
-                   OffsetedLocalBlock2ETileMap,
+                   KPerBlock,
+                   OffsettedLocalBlock2ETileMap,
                    Block2ETileMap,
                    AElementwiseOperation,
                    BElementwiseOperation,
-                   CDEElementwiseOperation>;
+                   CDEElementwiseOperation,
+                   BlkGemmPipeSched,
+                   BlkGemmPipelineVer>;

            return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
        }
@@ -546,6 +709,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
                          << std::endl;
            }

+            // run multiple kernels
            return launch_and_time_kernel(stream_config,
                                          kernel,
                                          dim3(grid_size),
@@ -572,63 +737,41 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
                return false;
            }

-        using DsGridDescMN = remove_cvref_t<
-            decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
-                {}, {}, {}))>;
-
        bool supported = true;
-        for(const auto& gdesc : arg.gemm_descs_)
+        constexpr index_t k_batch = 1;
+        for(index_t i = 0; i < arg.group_count_; ++i)
        {
-            const auto M = gdesc.M_;
-            const auto N = gdesc.N_;
-            const auto K = gdesc.K_;
-
-            const auto StrideA   = gdesc.stride_A_;
-            const auto StrideB   = gdesc.stride_B_;
-            const auto StrideE   = gdesc.stride_C_;
-            const auto& StrideDs = gdesc.stride_Ds_;
-
-            // If M dimension is unknown at launch time then validate just NK.
-            // If N or K dim is zero (or unknown) then the vector loads responsibility lies on
-            // the user.
-            if(N * K == 0)
-                continue;
-
-            const auto a_grid_desc_mk =
-                GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
-            const auto b_grid_desc_nk =
-                GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
-            const auto e_grid_desc_mn =
-                GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
-
-            DsGridDescMN ds_grid_desc_mn;
-            static_for<0, NumDTensor, 1>{}([&](auto j) {
-                using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
-                ds_grid_desc_mn(j) =
-                    GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
-                        M, N, StrideDs[j]);
-            });
-
-            const auto b2c_tile_map = Block2ETileMap(M, N);
-
-            if(!(GridwiseGemm::template CheckValidity(a_grid_desc_mk,
-                                                      b_grid_desc_nk,
-                                                      ds_grid_desc_mn,
-                                                      e_grid_desc_mn,
-                                                      b2c_tile_map) &&
-                 GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
-                     M, N, K)))
+            std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
+            std::array<index_t, NumDTensor> stride_Ds;
+            std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
+
+            using GridArg = typename GridwiseGemm::Argument;
+            GridArg gridwise_arg(nullptr,               // p_a_grid
+                                 nullptr,               // p_b_grid
+                                 placeholder_p_ds_grid, // p_ds_grid
+                                 nullptr,               // p_e_grid
+                                 arg.gemm_descs_[i].M_,
+                                 arg.gemm_descs_[i].N_,
+                                 arg.gemm_descs_[i].K_,
+                                 arg.gemm_descs_[i].stride_A_,
+                                 arg.gemm_descs_[i].stride_B_,
+                                 stride_Ds,
+                                 arg.gemm_descs_[i].stride_C_,
+                                 k_batch,
+                                 arg.a_element_op_,
+                                 arg.b_element_op_,
+                                 arg.cde_element_op_);
+
+            if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) &&
+               !(GemmSpec == GemmSpecialization::MKPadding ||
+                 GemmSpec == GemmSpecialization::NKPadding ||
+                 GemmSpec == GemmSpecialization::MNKPadding ||
+                 GemmSpec == GemmSpecialization::KPadding))
            {
-                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
-                {
-                    std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
-                              << K << "] are not supported by current template parameters!"
-                              << " In " << __FILE__ << ":" << __LINE__
-                              << ", in function: " << __func__;
-                }
-                supported = false;
+                return false;
            }
+
+            supported = supported && GridwiseGemm::CheckValidity(gridwise_arg);
        }

        return supported;
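The support check above reduces to one divisibility rule: each group's K must be a multiple of both AK1 and BK1 unless the chosen GemmSpecialization pads the K dimension. The following standalone sketch restates that rule for reference; it reuses the real GemmSpecialization enum, while the function name, the ck::index_t parameters, and the include list are assumptions of the sketch, not code from this PR.

#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/utility/common_header.hpp"

// Restates the K-divisibility rule from IsSupportedArgument: a group is rejected only
// when K is not a multiple of the A/B K1 vector widths and the selected specialization
// does not pad the K dimension.
inline bool IsKDimensionSupported(ck::index_t K,
                                  ck::index_t AK1,
                                  ck::index_t BK1,
                                  ck::tensor_operation::device::GemmSpecialization spec)
{
    using ck::tensor_operation::device::GemmSpecialization;

    const bool k_is_vector_aligned = (K % AK1 == 0) && (K % BK1 == 0);
    const bool spec_pads_k         = spec == GemmSpecialization::KPadding ||
                                     spec == GemmSpecialization::MKPadding ||
                                     spec == GemmSpecialization::NKPadding ||
                                     spec == GemmSpecialization::MNKPadding;

    return k_is_vector_aligned || spec_pads_k;
}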
@@ -651,16 +794,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
        const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
                KernelArguments,
                GemmSpec,
+               ADataType,
+               BDataType,
                DsDataType,
+               EDataType,
                ALayout,
                BLayout,
                DsLayout,
                ELayout,
-               OffsetedLocalBlock2ETileMap,
+               KPerBlock,
+               OffsettedLocalBlock2ETileMap,
                Block2ETileMap,
                AElementwiseOperation,
                BElementwiseOperation,
-               CDEElementwiseOperation>;
+               CDEElementwiseOperation,
+               BlkGemmPipeSched,
+               BlkGemmPipelineVer>;

        int occupancy, num_cu;
        hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
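The occupancy query above, together with the CU count kept in gpu_cu_count_, is what sizes the persistent tile-loop grid. The exact formula used by this device op sits outside the hunks shown here, so the sketch below only illustrates the common "one full wave of workgroups" choice; CalculateMaxOccupancyGridSize, the default device id 0, and the use of hipDeviceGetAttribute are assumptions of the sketch.

#include <hip/hip_runtime.h>
#include "ck/host_utility/hip_check_error.hpp"

// Illustrative sizing of a persistent tile-loop grid: launch just enough workgroups to
// fill every CU at the occupancy reported for this kernel; each workgroup then loops
// over its share of output tiles.
template <typename GroupedGemmKernel>
int CalculateMaxOccupancyGridSize(GroupedGemmKernel kernel, int block_size, int device_id = 0)
{
    int occupancy = 0;
    int num_cu    = 0;
    hip_check_error(
        hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, block_size, 0));
    hip_check_error(
        hipDeviceGetAttribute(&num_cu, hipDeviceAttributeMultiprocessorCount, device_id));
    return occupancy * num_cu;
}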
@@ -696,16 +845,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
        const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
                KernelArguments,
                GemmSpec,
+               ADataType,
+               BDataType,
                DsDataType,
+               EDataType,
                ALayout,
                BLayout,
                DsLayout,
                ELayout,
-               OffsetedLocalBlock2ETileMap,
+               KPerBlock,
+               OffsettedLocalBlock2ETileMap,
                Block2ETileMap,
                AElementwiseOperation,
                BElementwiseOperation,
-               CDEElementwiseOperation>;
+               CDEElementwiseOperation,
+               BlkGemmPipeSched,
+               BlkGemmPipelineVer>;

        int occupancy, num_cu;
        hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
@@ -739,6 +894,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
    {
        auto str = std::ostringstream();

+        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
+            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
+            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};
+
+        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
+            {BlockGemmPipelineVersion::v1, "v1"},
+            {BlockGemmPipelineVersion::v2, "v2"},
+            {BlockGemmPipelineVersion::v3, "v3"},
+            {BlockGemmPipelineVersion::v4, "v4"},
+            {BlockGemmPipelineVersion::v5, "v5"}};
+
        // clang-format off
        str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop"
            << "<"
@@ -760,8 +926,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle << ", "
            << getGemmSpecializationString(GemmSpec) << ", "
-            << PipelineVer << ", "
-            << LoopSched
+            << "BlkGemmPipelineScheduler: "
+            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
+            << "BlkGemmPipelineVersion: "
+            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
            << ">";
        // clang-format on
...
@@ -908,6 +908,51 @@ struct OffsettedBlockToCTileMap
    UnderlyingBlockToCTileMap block_to_ctile_map_;
    index_t block_start_;
};

+// second version with 2 offsets
+template <typename UnderlyingBlockToCTileMap>
+struct OffsettedBlockToCTileMap2
+{
+    using underlying_type = UnderlyingBlockToCTileMap;
+
+    __host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map,
+                                                  index_t group_offset,
+                                                  index_t tile_offset)
+        : block_to_ctile_map_{block_to_ctile_map},
+          group_offset_{group_offset},
+          tile_offset_{tile_offset}
+    {
+    }
+
+    template <typename TopIdx>
+    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+    {
+        return block_to_ctile_map_.CalculateBottomIndex(
+            make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
+    }
+
+    template <typename CTileIdx, typename CTileDim>
+    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                             const CTileDim& c_tile_dim) const
+    {
+        return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    {
+        return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
+    }
+
+    __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
+    {
+        return block_to_ctile_map_.CalculateGridSize(M, N);
+    }
+
+    __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
+
+    UnderlyingBlockToCTileMap block_to_ctile_map_;
+    index_t group_offset_;
+    index_t tile_offset_;
+};
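The second offset lets a persistent workgroup walk the tiles of one GEMM group while block ids stay global: group_offset_ pins the group's first tile, and tile_offset_ advances per iteration, so the underlying map only ever sees a group-local tile index. A minimal device-side usage sketch follows, assuming an already-constructed underlying map; the function and variable names are illustrative and not taken from the kernel that drives this map.

// Illustrative loop over one group's output tiles with OffsettedBlockToCTileMap2.
// group_offset is the global id of the group's first tile; local_tile walks the
// group's tiles in strides of the grid size.
template <typename UnderlyingMap>
__device__ void ProcessGroupTiles(UnderlyingMap underlying_map,
                                  index_t group_offset,
                                  index_t group_tile_count)
{
    OffsettedBlockToCTileMap2<UnderlyingMap> b2c{underlying_map, group_offset, group_offset};

    for(index_t local_tile = blockIdx.x; local_tile < group_tile_count;
        local_tile += gridDim.x)
    {
        b2c.UpdateTileOffset(group_offset + local_tile);
        // CalculateBottomIndex sees (0 + tile_offset_ - group_offset_) == local_tile and
        // delegates to the underlying map to resolve the (m-block, n-block) coordinate.
        const auto block_idx_m_n = b2c.CalculateBottomIndex(make_multi_index(0));
        (void)block_idx_m_n; // the GEMM tile for this coordinate would be computed here
    }
}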
/**
 * @brief Simple tile mapping which creates 3D grid of block of threads.
...