Unverified Commit 58d75b7a authored by M.Emin Ozturk, committed by GitHub

Merge branch 'develop' into gemm_bf16_sk_muozturk

parents 7ed95722 627a27bd
......@@ -89,7 +89,8 @@ struct DeviceBatchedGemmV2MultiD : public BaseOperator
index_t BatchStrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
CDEElementwiseOperation cde_element_op,
index_t KBatch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -41,12 +41,15 @@ __global__ void
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t g_idx = blockIdx.z % karg.Batch;
const index_t k_idx = blockIdx.z / karg.Batch;
const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
// populate pointer, desc for Ds
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
// D pointer
......@@ -54,8 +57,8 @@ __global__ void
});
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid + c_batch_offset,
p_shared,
......@@ -87,12 +90,15 @@ __global__ void
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t g_idx = blockIdx.z % karg.Batch;
const index_t k_idx = blockIdx.z / karg.Batch;
const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
// populate pointer, desc for Ds
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
// D pointer
......@@ -100,8 +106,8 @@ __global__ void
});
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid + c_batch_offset,
p_shared_0,
......@@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t Batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
CElementwiseOperation c_element_op_,
index_t KBatch_)
: GridwiseGemm::Argument{p_a_grid_,
p_b_grid_,
p_ds_grid_,
......@@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
StrideB_,
StrideDs_,
StrideE_,
1,
KBatch_,
a_element_op_,
b_element_op_,
c_element_op_},
......@@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg) || arg.KBatch > 1)
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch);
std::tie(gdx, gdy, gdz) =
GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch);
float ave_time = 0;
......@@ -387,9 +395,10 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
hipGetErrorString(
hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
arg.Batch * arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
......@@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t BatchStrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
index_t KBatch = 1)
{
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......@@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
Batch,
a_element_op,
b_element_op,
c_element_op};
c_element_op,
KBatch};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t BatchStrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
CElementwiseOperation c_element_op,
index_t KBatch = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......@@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
Batch,
a_element_op,
b_element_op,
c_element_op);
c_element_op,
KBatch);
}
// polymorphic
......
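Taken together, the hunks above extend the batched multi-D GEMM to split-K: the z grid dimension now covers Batch * KBatch workgroups, each block derives its batch index and its K-slice index from blockIdx.z, and when KBatch > 1 the C buffer is cleared first (presumably so the split-K partial results can be accumulated into it). Below is a minimal, self-contained model of the new index mapping; only the modulo/division scheme is taken from the kernel above, and `index_t` is assumed to be a 32-bit signed integer as in the library.

```cpp
// Host-side model of the split-K grid-z mapping (illustrative only).
#include <cassert>
#include <cstdint>

using index_t = int32_t;

// The launcher now asks for Batch * KBatch workgroups along grid z.
index_t grid_dim_z(index_t Batch, index_t KBatch) { return Batch * KBatch; }

struct ZIndices
{
    index_t g_idx; // which GEMM of the batch
    index_t k_idx; // which split-K slice of that GEMM
};

// Mirrors: g_idx = blockIdx.z % Batch; k_idx = blockIdx.z / Batch;
ZIndices decode_block_idx_z(index_t block_idx_z, index_t Batch)
{
    return {block_idx_z % Batch, block_idx_z / Batch};
}

int main()
{
    const index_t Batch = 4, KBatch = 2;
    assert(grid_dim_z(Batch, KBatch) == 8);
    const ZIndices z = decode_block_idx_z(/*block_idx_z=*/6, Batch);
    assert(z.g_idx == 2 && z.k_idx == 1);
    return 0;
}
```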
......@@ -729,6 +729,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return str.str();
}
REGISTER_EXTRA_PRINTING_METHODS
};
} // namespace device
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -106,8 +106,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static constexpr auto I3 = Number<3>{};
static constexpr index_t KPerBlock = K0PerBlock * K1;
static constexpr auto transform_conv_to_gemm =
TransformConvBwdDataToGemm_v1<NDimSpatial,
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
ConvBackwardDataSpecialization,
K1,
K1,
......@@ -115,80 +114,27 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
NPerBlock,
KPerBlock,
true /* DoPadGemmM */,
true /* DoPadGemmN */>{};
true /* DoPadGemmN */,
ALayout,
BLayout,
ELayout>;
static auto GetDummyABDsEGridDescriptor()
static auto
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
{
const std::array<index_t, NDimSpatial + 3> dummy_tensor_lengths = {1};
const std::array<index_t, NDimSpatial + 3> dummy_tensor_strides = {1};
const std::array<index_t, NDimSpatial> dummy_spatial_lengths = {1};
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto ds_grid_desc_m_n = generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
},
const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
const auto ds_grid_desc_m_n =
generate_tuple([&](auto) { return conv_to_gemm_transform.MakeCDescriptor_M_N(); },
Number<NumDTensor>{});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
return make_tuple(
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
}
// desc
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor());
constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform));
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
......@@ -270,7 +216,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_lengths,
/*ds_g_n_c_wis_lengths*/,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths,
......@@ -291,15 +237,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
......@@ -382,9 +321,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
tildes = {i_ztilde, i_ytilde, i_xtilde};
}
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
a_g_n_k_wos_lengths,
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
......@@ -394,56 +331,37 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
tildes};
const auto a_grid_desc_ak0_m_ak1 =
conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
ds_grid_desc_m_n(i) =
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
static_assert(is_same_v<DLayout, ELayout>);
ConvToGemmBwdDataTransform conv_to_gemm_transform_d{
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
e_g_n_c_wis_lengths,
ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
tildes};
ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
// for check validity
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
......@@ -522,17 +440,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
BElementwiseOp b_element_op_;
CDEElementwiseOp cde_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
};
......
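The convolution-backward-data hunks above replace a shared `transform_conv_to_gemm` member, whose templated `Make*Descriptor*<Layout>(...)` methods took the full set of lengths, strides, and pads on every call, with a `ConvToGemmBwdDataTransform` type that captures the whole problem in its constructor, so the descriptor factories become argument-free (and a separate instance is built per D tensor). Below is a toy model of that refactoring pattern; all names and the placeholder arithmetic are hypothetical, not the library API.

```cpp
// Toy sketch of "move the call-site arguments into the constructor".
#include <array>
#include <cstdint>

using index_t = int32_t;

struct ToyConvToGemmTransform
{
    // Problem description captured once at construction time.
    std::array<index_t, 4> out_lengths;
    std::array<index_t, 4> wei_lengths;

    ToyConvToGemmTransform(const std::array<index_t, 4>& out,
                           const std::array<index_t, 4>& wei)
        : out_lengths(out), wei_lengths(wei)
    {
    }

    // Argument-free factories, analogous to MakeADescriptor_AK0_M_AK1() /
    // MakeCDescriptor_M_N() above; the arithmetic here is only a placeholder.
    index_t MakeADescriptor() const { return out_lengths[0] * out_lengths[1]; }
    index_t MakeCDescriptor() const { return wei_lengths[0] * wei_lengths[1]; }
};

int main()
{
    // One transform per output tensor; a second instance with D's lengths would be
    // built the same way, as conv_to_gemm_transform_d is in the hunk above.
    ToyConvToGemmTransform t({2, 64, 7, 7}, {64, 32, 3, 3});
    return (t.MakeADescriptor() == 2 * 64 && t.MakeCDescriptor() == 64 * 32) ? 0 : 1;
}
```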
......@@ -41,7 +41,7 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
......@@ -76,7 +76,7 @@ __global__ void
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
......@@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead;
a_k_split_offset = k_id * karg.KRead;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead;
b_k_split_offset = k_id * karg.KRead;
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
if(k_id < karg.KBatch - 1)
{
karg.K = karg.KRead;
}
......
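The SplitKBatchOffset constructor now receives the split index explicitly instead of reading blockIdx.z itself, so the plain split-K kernels can keep passing blockIdx.z while the batched kernels pass blockIdx.z / Batch. Below is a self-contained model of the per-layout offset arithmetic it applies; KRead is the K extent each slice reads and StrideA is the leading dimension, with names following the fragment above, but this is not the library code.

```cpp
// Model of the A start offset chosen per split-K slice, depending on layout.
#include <cstdint>

using index_t = int32_t;

// Row-major A: K is the fast-moving dimension, so slice k_id starts k_id * KRead elements in.
index_t a_offset_row_major(index_t k_id, index_t KRead) { return k_id * KRead; }

// Column-major A: advancing along K steps by StrideA, so the slice starts k_id * KRead * StrideA in.
index_t a_offset_col_major(index_t k_id, index_t KRead, index_t StrideA)
{
    return k_id * KRead * StrideA;
}

int main()
{
    const index_t k_id = 3, KRead = 256, StrideA = 1024;
    // Slice 3 begins 768 elements into K for row-major A,
    // and 768 * 1024 elements into the buffer for column-major A.
    return (a_offset_row_major(k_id, KRead) == 768 &&
            a_offset_col_major(k_id, KRead, StrideA) == 768 * StrideA) ? 0 : 1;
}
```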
......@@ -611,7 +611,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x)
template <>
inline __device__ half_t neg<half_t>(half_t x)
{
return __hneg(x);
return __hneg(static_cast<__half>(x));
};
template <typename T>
......
......@@ -45,5 +45,8 @@ our implementation of different device operators.
**[ops/epilogue]**
Epilogue part of our kernel. We may extend this epilogue part to let users build their own customized epilogues.
**[ref]**
Reference implementations for CPU or GPU. This folder is supposed to include specific headers on demand.
## examples
Currently we put all ck_tile-related examples under the [/example/ck_tile](/example/ck_tile/) folder. Please check each example's subfolder.
......@@ -54,6 +54,7 @@
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
......
......@@ -5,6 +5,7 @@
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
......@@ -998,14 +998,14 @@ struct FmhaFwdKernel
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{});
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{});
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
}();
const auto k_dram = [&]() {
......@@ -1019,7 +1019,7 @@ struct FmhaFwdKernel
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{});
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
......@@ -1041,7 +1041,7 @@ struct FmhaFwdKernel
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, false>{});
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
else
{
......@@ -1055,7 +1055,7 @@ struct FmhaFwdKernel
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<false, kPadSeqLenK>{});
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
}();
......@@ -1097,8 +1097,9 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentBias>{},
number<1>{});
return pad_tensor_view(
bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
return pad_tensor_view(bias_dram_naive,
bias_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
}();
return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
......
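The FMHA hunks above stop hard-coding `false` for the sequence-length dimensions and instead honor the kernel's compile-time kPadSeqLenQ / kPadSeqLenK / kPadHeadDimV flags, so the Q, K, V, and bias views follow the configured padding along seqlen as well. Below is a minimal model of what each flag in the `sequence<>` selects; it is illustrative and not the ck_tile pad_tensor_view implementation.

```cpp
// Per-dimension padding to the next tile multiple, gated by a compile-time flag.
#include <cstdint>

using index_t = int32_t;

index_t pad_dim(index_t length, index_t tile, bool do_pad)
{
    return do_pad ? ((length + tile - 1) / tile) * tile : length;
}

int main()
{
    constexpr bool kPadSeqLenQ = true;
    constexpr index_t kM0      = 128; // tile length along seqlen_q
    // seqlen_q = 1000 is not a multiple of 128: with the old hard-coded `false` this
    // dimension was never padded; with the flag enabled it is rounded up to 1024.
    return pad_dim(1000, kM0, kPadSeqLenQ) == 1024 ? 0 : 1;
}
```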
......@@ -810,21 +810,46 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
{
using S_ = typename Problem::BlockShape;
using T_ = typename Problem::Traits;
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == false)
{
return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16{};
// return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == false)
{
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == true)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == true)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
}
}
};
......
......@@ -22,7 +22,8 @@ template <bool IsGateOnly_,
FusedMoeGemmWeightPermuteEnum PermuteEnum_ =
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten,
bool PadHiddenSize_ = false,
bool PadIntermediateSize_ = false>
bool PadIntermediateSize_ = false,
bool PipeInterleave_ = true>
struct FusedMoeGemmTraits
{
// Gate+Up or Gate only
......@@ -32,6 +33,7 @@ struct FusedMoeGemmTraits
static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_;
static constexpr bool PadHiddenSize = PadHiddenSize_;
static constexpr bool PadIntermediateSize = PadIntermediateSize_;
static constexpr bool PipeInterleave = PipeInterleave_;
};
// Note: this need to be a bit mask
......
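GetUK_1() now also dispatches on the new Traits::PipeInterleave flag (default true, per the trait above), selecting the interleaved `_itl` flatmm micro-kernels when it is set. Below is a compile-time model of that dispatch with stand-in types; these are not the real micro-kernels or traits.

```cpp
// if constexpr dispatch on a trait flag, mirroring the PipeInterleave branches above.
#include <type_traits>

struct FlatmmSn_Plain {};        // stand-in for FlatmmSn_32x128x512_1x4x1_16x16x32_BF16
struct FlatmmSn_Interleaved {};  // stand-in for FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl

template <bool PipeInterleave_>
struct ToyTraits { static constexpr bool PipeInterleave = PipeInterleave_; };

template <typename Traits>
constexpr auto get_uk_1()
{
    if constexpr(Traits::PipeInterleave)
        return FlatmmSn_Interleaved{};
    else
        return FlatmmSn_Plain{};
}

static_assert(std::is_same_v<decltype(get_uk_1<ToyTraits<true>>()), FlatmmSn_Interleaved>);
static_assert(std::is_same_v<decltype(get_uk_1<ToyTraits<false>>()), FlatmmSn_Plain>);

int main() { return 0; }
```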
......@@ -23,10 +23,10 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
......
......@@ -19,7 +19,8 @@ struct SmoothquantHostArgs
index_t m;
index_t n;
index_t stride; // row_stride
index_t x_stride; // input row_stride
index_t y_stride; // output row_stride
};
// TODO: Extract some type to wrapper class
......@@ -58,14 +59,21 @@ struct Smoothquant
index_t m;
index_t n;
index_t stride; // row_stride
index_t x_stride; // input row_stride
index_t y_stride; // out row_stride
};
using Hargs = SmoothquantHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{
hargs.p_x, hargs.p_xscale, hargs.p_yscale, hargs.p_qy, hargs.m, hargs.n, hargs.stride};
return Kargs{hargs.p_x,
hargs.p_xscale,
hargs.p_yscale,
hargs.p_qy,
hargs.m,
hargs.n,
hargs.x_stride,
hargs.y_stride};
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
......@@ -116,7 +124,7 @@ struct Smoothquant
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
make_tuple(kargs.x_stride, 1),
number<Vector_N>{},
number<1>{});
......@@ -157,7 +165,7 @@ struct Smoothquant
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<QYDataType*>(kargs.p_qy),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
make_tuple(kargs.y_stride, 1),
number<Vector_N>{},
number<1>{});
......
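The Smoothquant change splits the single `stride` into `x_stride` for the input view and `y_stride` for the quantized output view, so the two tensors no longer have to share a row pitch. Below is a small model of the addressing this enables; it is illustrative and not the ck_tile tensor-view API.

```cpp
// Separate leading dimensions for the fp input x and the quantized output qy.
#include <cstddef>

std::size_t x_elem_offset(std::size_t row, std::size_t col, std::size_t x_stride)
{
    return row * x_stride + col;
}

std::size_t y_elem_offset(std::size_t row, std::size_t col, std::size_t y_stride)
{
    return row * y_stride + col;
}

int main()
{
    const std::size_t n        = 4096;
    const std::size_t x_stride = 4352; // e.g. input rows padded for alignment
    const std::size_t y_stride = n;    // quantized output stored densely
    // Same logical element (row 2, col 7), different linear offsets in each buffer.
    return (x_elem_offset(2, 7, x_stride) == 2 * 4352 + 7 &&
            y_elem_offset(2, 7, y_stride) == 2 * 4096 + 7) ? 0 : 1;
}
```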