Commit ed305f6b authored by Umang Yadav

formatting

parent 9f4e3544
@@ -108,12 +108,12 @@ struct TensorAdaptor
     __host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
     {
-        constexpr auto all_low_dim_ids =
-            unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
-                   LowerDimensionHiddenIdss{});
+        constexpr auto all_low_dim_ids = unpack(
+            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
+            LowerDimensionHiddenIdss{});
-        constexpr auto all_up_dim_ids =
-            unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
-                   UpperDimensionHiddenIdss{});
+        constexpr auto all_up_dim_ids = unpack(
+            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
+            UpperDimensionHiddenIdss{});
         constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
@@ -338,7 +338,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
             TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];
         // sequence in, sequence out
-        constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
+        constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
+        {
             auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
             // shift hidden id so every dim id is unique
@@ -360,7 +361,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
             });
             return low_dim_hidden_ids_1_mod_;
-        }();
+        }
+        ();
         return generate_sequence_v2(
             [&](auto i) constexpr { return Number<low_dim_hidden_ids_1_mod[i]>{}; },
@@ -382,7 +384,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
             TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];
         // sequence in, constexpr tuple out
-        constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
+        constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
+        {
             auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
             // shift hidden id
@@ -391,7 +394,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
             });
             return up_dim_hidden_ids_1_mod_;
-        }();
+        }
+        ();
         // constexpr tuple to sequence
         return generate_sequence_v2(
...
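Side note, not part of the commit: the lines above touch the immediately-invoked constexpr lambda idiom, where a lambda is defined and called in one expression to initialize a `constexpr` variable; the reformatting only moves the trailing `();` onto its own line. A minimal standalone sketch of the idiom, with a made-up `shifted_ids` array standing in for the real hidden-id logic:

```cpp
// Illustrative sketch only; "shifted_ids" and the shift value are invented.
#include <array>

constexpr auto shifted_ids = []() constexpr {
    std::array<int, 3> ids{0, 1, 2};
    for(auto& id : ids)
    {
        id += 4; // e.g. "shift hidden id so every dim id is unique"
    }
    return ids;
}(); // the lambda is invoked immediately; clang-format may split "}();"

static_assert(shifted_ids[0] == 4 && shifted_ids[2] == 6, "computed at compile time");
```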
@@ -94,8 +94,10 @@ struct SpaceFillingCurve
         // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
         // idim-th element of multidimensional index.
         // All constexpr variables have to be captured by VALUE.
-        constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
-            constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
+        constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
+        {
+            constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
+            {
                 auto res = idx_1d.value;
                 auto id  = 0;
...
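Again as a standalone illustration: the space-filling-curve hunk above captures `idx_1d` and `access_strides` by value, as the comment "All constexpr variables have to be captured by VALUE" requires. A reduced sketch of the by-value-capture pattern, with `Number` and the index math simplified stand-ins rather than the library's implementation:

```cpp
// Reduced sketch; not CK's SpaceFillingCurve, just the capture pattern.
template <int V>
struct Number
{
    static constexpr int value = V;
};

int main()
{
    constexpr auto idx_1d = Number<7>{};
    // Capture by VALUE: the lambda owns a copy of idx_1d, so it does not
    // reference the enclosing stack frame when it is evaluated later.
    auto compute_index = [idx_1d](int idim) { return idx_1d.value + idim; };
    return compute_index(3) == 10 ? 0 : 1;
}
```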
@@ -47,7 +47,6 @@ struct BaseOperator
     virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
 #ifndef __HIPCC_RTC__
     virtual std::string GetTypeString() const { return ""; }
@@ -66,7 +65,7 @@ struct BaseOperator
     virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
     {
-        //assert(p_arg);
+        // assert(p_arg);
         p_arg->p_workspace_ = p_workspace;
     }
...
@@ -38,7 +38,7 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_contraction_multiple_d_xdl_cshuffle(
             const FloatAB* __restrict__ p_a_grid,
...
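For readers unfamiliar with the `CK_USE_LAUNCH_BOUNDS` guard that the whitespace-only hunks in this commit keep touching, here is a self-contained sketch of the pattern; the numeric values below are assumptions for illustration, not the project's actual settings:

```cpp
// Sketch of the guard pattern around __launch_bounds__ (HIP/CUDA).
// The CK_MAX_THREAD_PER_BLOCK / CK_MIN_BLOCK_PER_CU values here are assumed.
#define CK_USE_LAUNCH_BOUNDS 1
#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 1

__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        example_kernel(float* __restrict__ p_out)
{
    // __launch_bounds__(max_threads_per_block, min_blocks_per_multiprocessor)
    // caps register usage so the scheduler can keep more blocks resident.
    p_out[blockIdx.x * blockDim.x + threadIdx.x] = 0.0f;
}
```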
@@ -60,7 +60,7 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
                                           const ABDataType* __restrict__ p_b_grid,
...
@@ -41,9 +41,10 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
+        kernel_gemm_gemm_xdl_cshuffle_v1(
+            const FloatAB* __restrict__ p_a_grid,
             const FloatAB* __restrict__ p_b_grid,
             const FloatAB* __restrict__ p_b1_grid,
             FloatC* __restrict__ p_c_grid,
...
@@ -63,7 +63,7 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid,
                                 const ABDataType* __restrict__ p_b_grid,
...
@@ -52,7 +52,7 @@ template <typename GridwiseGemm,
           bool HasDoubleTailKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_gemm_dl_multiple_d(
             const ABDataType* __restrict__ p_a_grid,
...
@@ -41,7 +41,7 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_batched_gemm_gemm_xdl_cshuffle_v1(
             const A0B0B1DataType* __restrict__ p_a0_grid,
...
@@ -38,7 +38,7 @@ template <typename GridwiseGemm,
           bool HasMainK0BlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_batched_gemm_reduce_xdl_cshuffle_v1(
             const FloatAB* __restrict__ p_a_grid,
...
@@ -42,7 +42,7 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
             const FloatAB* __restrict__ p_a_grid,
...
@@ -40,7 +40,7 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
             const FloatAB* __restrict__ p_a_grid,
@@ -611,7 +611,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return true;
     }
-    static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
+    static constexpr bool
+    IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
     {
         // check vector load/store
         using Row = ck::tensor_layout::gemm::RowMajor;
@@ -842,7 +843,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
     template <class ADesc, class BDesc, class B1Desc, class CDesc>
     struct Descriptor
     {
-        template<class AGridDescriptor>
+        template <class AGridDescriptor>
         static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
         {
             const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
@@ -852,14 +853,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
             const auto AK0 = K / AK1;
-            return transform_tensor_descriptor(a_grid_desc_m_k,
+            return transform_tensor_descriptor(
+                a_grid_desc_m_k,
                 make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                            make_pass_through_transform(M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }
-        template<class BGridDescriptor>
+        template <class BGridDescriptor>
         static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
         {
             const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
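As an aside, the descriptor hunks here split the K dimension into (AK0, AK1) with AK0 = K / AK1 before handing the tensor to the gridwise GEMM. A plain-arithmetic sketch of that unmerge, with made-up concrete sizes:

```cpp
// Arithmetic illustration only; M, K and AK1 below are arbitrary example sizes.
#include <cassert>

int main()
{
    const int M   = 128;
    const int K   = 64;
    const int AK1 = 8;       // packet width read along K
    const int AK0 = K / AK1; // 8, as in "const auto AK0 = K / AK1" above

    // In the (AK0, M, AK1) view, element (ak0, m, ak1) maps back to
    // column k = ak0 * AK1 + ak1 of the original M x K matrix.
    const int ak0 = 3, ak1 = 2;
    const int k   = ak0 * AK1 + ak1;

    assert(AK0 * AK1 == K);
    assert(k == 26);
    (void)M;
    return 0;
}
```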
@@ -869,14 +871,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
             const auto BK0 = K / BK1;
-            return transform_tensor_descriptor(b_grid_desc_n_k,
+            return transform_tensor_descriptor(
+                b_grid_desc_n_k,
                 make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }
-        template<class B1GridDescriptor>
+        template <class B1GridDescriptor>
         static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
         {
             const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
@@ -894,21 +897,19 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }
-        template<class CGridDescriptor>
+        template <class CGridDescriptor>
         static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
         {
             return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
         }
         using AGridDesc_AK0_M_AK1 =
             remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
         using BGridDesc_BK0_N_BK1 =
             remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
         using B1GridDesc_BK0_N_BK1 =
             remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
-        using CGridDesc_M_N =
-            remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
+        using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
         // GridwiseGemm
         using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
@@ -979,7 +980,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
             CGridDesc_M_N c_grid_desc_m_n;
             C0MatrixMask c0_matrix_mask;
             typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
-            typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock;
+            typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                c_grid_descriptor_mblock_mperblock_nblock_nperblock;
             // element-wise op
             AElementwiseOperation a_element_op;
@@ -1002,10 +1004,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
               b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
               b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
               c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
-              block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(
-                  c_grid_desc_m_n)},
+              block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
               c_grid_descriptor_mblock_mperblock_nblock_nperblock{
-                  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n)},
+                  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                      c_grid_desc_m_n)},
               has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
                   a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
               c0_matrix_mask{c.GetLength(I1)},
@@ -1013,23 +1015,20 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
              b_element_op{b_element_op_},
              b1_element_op{b1_element_op_},
              c_element_op{c_element_op_},
-             is_valid{GridwiseGemm::CheckValidity(
-                 a_grid_desc_ak0_m_ak1,
+             is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   b1_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_m_n,
                                                   block_2_ctile_map) and
                       IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
                                   b_grid_desc_bk0_n_bk1.GetLength(I1),
-                                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2),
+                                  a_grid_desc_ak0_m_ak1.GetLength(I0) *
+                                      a_grid_desc_ak0_m_ak1.GetLength(I2),
                                   b1_grid_desc_bk0_n_bk1.GetLength(I1))}
         {
         }
-        constexpr bool IsValid() const
-        {
-            return is_valid;
-        }
+        constexpr bool IsValid() const { return is_valid; }
     };
     template <class ADesc, class BDesc, class B1Desc, class CDesc>
@@ -1061,7 +1060,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
             if(desc.has_main_k_block_loop)
             {
-                Desc::GridwiseGemm::template Run<true>(p_a_grid,
+                Desc::GridwiseGemm::template Run<true>(
+                    p_a_grid,
                     p_b_grid,
                     p_b1_grid,
                     p_c_grid,
@@ -1080,7 +1080,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
             }
             else
             {
-                Desc::GridwiseGemm::template Run<false>(p_a_grid,
+                Desc::GridwiseGemm::template Run<false>(
+                    p_a_grid,
                     p_b_grid,
                     p_b1_grid,
                     p_c_grid,
...
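Finally, the `Run<true>` / `Run<false>` calls in the hunk above turn the runtime flag `desc.has_main_k_block_loop` into a compile-time template argument, so each branch gets its own optimized instantiation. A standalone sketch of that dispatch pattern, where `GemmRunner` and `run_gemm` are invented names rather than CK's types:

```cpp
// Dispatch sketch; not CK's GridwiseGemm, just the shape of the pattern.
#include <cstdio>

struct GemmRunner
{
    template <bool HasMainKBlockLoop>
    static void Run(int k_iters)
    {
        if constexpr(HasMainKBlockLoop)
        {
            std::printf("main K-loop path, %d iterations\n", k_iters);
        }
        else
        {
            std::printf("tail-only path\n");
        }
    }
};

void run_gemm(bool has_main_k_block_loop, int k_iters)
{
    // One runtime branch selects between two compile-time instantiations.
    if(has_main_k_block_loop)
    {
        GemmRunner::Run<true>(k_iters);
    }
    else
    {
        GemmRunner::Run<false>(k_iters);
    }
}
```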
@@ -48,7 +48,7 @@ namespace device {
 template <typename DeviceOp, typename GridwiseGemm, bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
 {
...
@@ -34,7 +34,7 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_contraction_multiple_d_xdl_cshuffle(
             const FloatAB* __restrict__ p_a_grid,
...
@@ -37,7 +37,7 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_gemm_xdlops_v2r3_for_conv3d(
             const FloatAB* __restrict__ p_a_grid,
...