Unverified Commit 58d75b7a authored by M.Emin Ozturk's avatar M.Emin Ozturk Committed by GitHub
Browse files

Merge branch 'develop' into gemm_bf16_sk_muozturk

parents 7ed95722 627a27bd
...@@ -89,7 +89,8 @@ struct DeviceBatchedGemmV2MultiD : public BaseOperator ...@@ -89,7 +89,8 @@ struct DeviceBatchedGemmV2MultiD : public BaseOperator
index_t BatchStrideE, index_t BatchStrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0; CDEElementwiseOperation cde_element_op,
index_t KBatch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -41,12 +41,15 @@ __global__ void ...@@ -41,12 +41,15 @@ __global__ void
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t g_idx = blockIdx.z % karg.Batch; const index_t g_idx = blockIdx.z % karg.Batch;
const index_t k_idx = blockIdx.z / karg.Batch;
const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
// D pointer // D pointer
...@@ -54,8 +57,8 @@ __global__ void ...@@ -54,8 +57,8 @@ __global__ void
}); });
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset, karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + b_batch_offset, karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid, karg.p_ds_grid,
karg.p_c_grid + c_batch_offset, karg.p_c_grid + c_batch_offset,
p_shared, p_shared,
...@@ -87,12 +90,15 @@ __global__ void ...@@ -87,12 +90,15 @@ __global__ void
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t g_idx = blockIdx.z % karg.Batch; const index_t g_idx = blockIdx.z % karg.Batch;
const index_t k_idx = blockIdx.z / karg.Batch;
const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
// D pointer // D pointer
...@@ -100,8 +106,8 @@ __global__ void ...@@ -100,8 +106,8 @@ __global__ void
}); });
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset, karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + b_batch_offset, karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid, karg.p_ds_grid,
karg.p_c_grid + c_batch_offset, karg.p_c_grid + c_batch_offset,
p_shared_0, p_shared_0,
...@@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t Batch_, index_t Batch_,
AElementwiseOperation a_element_op_, AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_, BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_) CElementwiseOperation c_element_op_,
index_t KBatch_)
: GridwiseGemm::Argument{p_a_grid_, : GridwiseGemm::Argument{p_a_grid_,
p_b_grid_, p_b_grid_,
p_ds_grid_, p_ds_grid_,
...@@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
StrideB_, StrideB_,
StrideDs_, StrideDs_,
StrideE_, StrideE_,
1, KBatch_,
a_element_op_, a_element_op_,
b_element_op_, b_element_op_,
c_element_op_}, c_element_op_},
...@@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
arg.Print(); arg.Print();
} }
if(!GridwiseGemm::CheckValidity(arg) || arg.KBatch > 1) if(!GridwiseGemm::CheckValidity(arg))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch); std::tie(gdx, gdy, gdz) =
GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch);
float ave_time = 0; float ave_time = 0;
...@@ -387,10 +395,11 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -387,10 +395,11 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
rotating_mem.Next(); rotating_mem.Next();
// clear c mem // clear c mem
if(arg_.KBatch > 1) if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, hipGetErrorString(
0, hipMemsetAsync(arg_.p_c_grid,
arg_.M * arg_.N * sizeof(CDataType), 0,
stream_config.stream_id_)); arg.Batch * arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
}; };
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>( ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
...@@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t BatchStrideE, index_t BatchStrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op,
index_t KBatch = 1)
{ {
return Argument{static_cast<const ADataType*>(p_a), return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
Batch, Batch,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op}; c_element_op,
KBatch};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t BatchStrideE, index_t BatchStrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op,
index_t KBatch = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
Batch, Batch,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op,
KBatch);
} }
// polymorphic // polymorphic
......
...@@ -729,6 +729,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -729,6 +729,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return str.str(); return str.str();
} }
REGISTER_EXTRA_PRINTING_METHODS
}; };
} // namespace device } // namespace device
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -106,89 +106,35 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -106,89 +106,35 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr index_t KPerBlock = K0PerBlock * K1; static constexpr index_t KPerBlock = K0PerBlock * K1;
static constexpr auto transform_conv_to_gemm = using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
TransformConvBwdDataToGemm_v1<NDimSpatial, ConvBackwardDataSpecialization,
ConvBackwardDataSpecialization, K1,
K1, K1,
K1, MPerBlock,
MPerBlock, NPerBlock,
NPerBlock, KPerBlock,
KPerBlock, true /* DoPadGemmM */,
true /* DoPadGemmM */, true /* DoPadGemmN */,
true /* DoPadGemmN */>{}; ALayout,
BLayout,
static auto GetDummyABDsEGridDescriptor() ELayout>;
{
const std::array<index_t, NDimSpatial + 3> dummy_tensor_lengths = {1};
const std::array<index_t, NDimSpatial + 3> dummy_tensor_strides = {1};
const std::array<index_t, NDimSpatial> dummy_spatial_lengths = {1};
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto ds_grid_desc_m_n = generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
},
Number<NumDTensor>{});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
static auto
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
{
const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
const auto ds_grid_desc_m_n =
generate_tuple([&](auto) { return conv_to_gemm_transform.MakeCDescriptor_M_N(); },
Number<NumDTensor>{});
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
return make_tuple( return make_tuple(
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
} }
// desc // desc
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor()); constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform));
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>; using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>; using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
...@@ -270,7 +216,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -270,7 +216,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_lengths, /*ds_g_n_c_wis_lengths*/,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_strides, ds_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths,
...@@ -291,15 +237,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -291,15 +237,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads} input_right_pads_{input_right_pads}
{ {
...@@ -382,68 +321,47 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -382,68 +321,47 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
tildes = {i_ztilde, i_ytilde, i_xtilde}; tildes = {i_ztilde, i_ytilde, i_xtilde};
} }
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes};
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>( conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>( conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
DsGridDesc_M_N ds_grid_desc_m_n; DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc // populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
static_assert(is_same_v<DLayout, ELayout>);
ds_grid_desc_m_n(i) = ConvToGemmBwdDataTransform conv_to_gemm_transform_d{
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
a_g_n_k_wos_lengths, a_g_n_k_wos_lengths,
a_g_n_k_wos_strides, a_g_n_k_wos_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_c_wis_lengths, e_g_n_c_wis_lengths,
e_g_n_c_wis_strides, ds_g_n_c_wis_strides[i],
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
tildes); tildes};
ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
});
const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
// for check validity // for check validity
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
...@@ -522,17 +440,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -522,17 +440,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
BElementwiseOp b_element_op_; BElementwiseOp b_element_op_;
CDEElementwiseOp cde_element_op_; CDEElementwiseOp cde_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_; std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_; std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_; std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_; std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_; std::array<index_t, NDimSpatial> input_right_pads_;
}; };
......
...@@ -41,7 +41,7 @@ __global__ void ...@@ -41,7 +41,7 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
...@@ -76,7 +76,7 @@ __global__ void ...@@ -76,7 +76,7 @@ __global__ void
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
...@@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
struct SplitKBatchOffset struct SplitKBatchOffset
{ {
__device__ SplitKBatchOffset(Argument& karg) __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{ {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{ {
a_k_split_offset = blockIdx.z * karg.KRead; a_k_split_offset = k_id * karg.KRead;
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{ {
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; a_k_split_offset = k_id * karg.KRead * karg.StrideA;
} }
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{ {
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; b_k_split_offset = k_id * karg.KRead * karg.StrideB;
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{ {
b_k_split_offset = blockIdx.z * karg.KRead; b_k_split_offset = k_id * karg.KRead;
} }
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1)) if(k_id < karg.KBatch - 1)
{ {
karg.K = karg.KRead; karg.K = karg.KRead;
} }
......
...@@ -611,7 +611,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x) ...@@ -611,7 +611,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x)
template <> template <>
inline __device__ half_t neg<half_t>(half_t x) inline __device__ half_t neg<half_t>(half_t x)
{ {
return __hneg(x); return __hneg(static_cast<__half>(x));
}; };
template <typename T> template <typename T>
......
...@@ -45,5 +45,8 @@ our implementation of different device operators. ...@@ -45,5 +45,8 @@ our implementation of different device operators.
**[ops/epilogue]** **[ops/epilogue]**
epilogue part of our kernel. We may extend this epilogue part to let users to build their own cutomized epilogues. epilogue part of our kernel. We may extend this epilogue part to let users to build their own cutomized epilogues.
**[ref]**
reference implementation of cpu or gpu. This folder is supposed to include a specific header on demand.
## examples ## examples
currently we put all ck_tile related example under [/example/ck_tile](/example/ck_tile/) folder. Please check each example's subfolder. currently we put all ck_tile related example under [/example/ck_tile](/example/ck_tile/) folder. Please check each example's subfolder.
...@@ -54,6 +54,7 @@ ...@@ -54,6 +54,7 @@
#include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp" #include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp" #include "ck_tile/core/utility/functional_with_tuple.hpp"
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" #include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
...@@ -998,14 +998,14 @@ struct FmhaFwdKernel ...@@ -998,14 +998,14 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
q_dram_naive, q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}), make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{}); sequence<kPadSeqLenQ, kPadHeadDimQ>{});
} }
else else
{ {
return pad_tensor_view( return pad_tensor_view(
q_dram_naive, q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}), make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{}); sequence<kPadSeqLenQ, kPadHeadDimQ>{});
} }
}(); }();
const auto k_dram = [&]() { const auto k_dram = [&]() {
...@@ -1019,7 +1019,7 @@ struct FmhaFwdKernel ...@@ -1019,7 +1019,7 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
k_dram_naive, k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{}); sequence<kPadSeqLenK, kPadHeadDimQ>{});
}(); }();
const auto v_dram = [&]() { const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
...@@ -1041,7 +1041,7 @@ struct FmhaFwdKernel ...@@ -1041,7 +1041,7 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
v_dram_transposed, v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, false>{}); sequence<kPadHeadDimV, kPadSeqLenK>{});
} }
else else
{ {
...@@ -1055,7 +1055,7 @@ struct FmhaFwdKernel ...@@ -1055,7 +1055,7 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
v_dram_naive, v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<false, kPadSeqLenK>{}); sequence<kPadHeadDimV, kPadSeqLenK>{});
} }
}(); }();
...@@ -1097,8 +1097,9 @@ struct FmhaFwdKernel ...@@ -1097,8 +1097,9 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentBias>{}, number<FmhaPipeline::kAlignmentBias>{},
number<1>{}); number<1>{});
return pad_tensor_view( return pad_tensor_view(bias_dram_naive,
bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{}); bias_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
}(); }();
return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
......
...@@ -810,21 +810,46 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -810,21 +810,46 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetUK_1() CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
{ {
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
using T_ = typename Problem::Traits;
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> && if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> && std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> && std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 && S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32) S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == false)
{ {
return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16{}; return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16{};
// return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl{};
} }
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> && else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> && std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> && std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 && S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32) S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == false)
{ {
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{}; return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == true)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == true)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
} }
} }
}; };
......
...@@ -22,7 +22,8 @@ template <bool IsGateOnly_, ...@@ -22,7 +22,8 @@ template <bool IsGateOnly_,
FusedMoeGemmWeightPermuteEnum PermuteEnum_ = FusedMoeGemmWeightPermuteEnum PermuteEnum_ =
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten, FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten,
bool PadHiddenSize_ = false, bool PadHiddenSize_ = false,
bool PadIntermediateSize_ = false> bool PadIntermediateSize_ = false,
bool PipeInterleave_ = true>
struct FusedMoeGemmTraits struct FusedMoeGemmTraits
{ {
// Gate+Up or Gate only // Gate+Up or Gate only
...@@ -32,6 +33,7 @@ struct FusedMoeGemmTraits ...@@ -32,6 +33,7 @@ struct FusedMoeGemmTraits
static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_; static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_;
static constexpr bool PadHiddenSize = PadHiddenSize_; static constexpr bool PadHiddenSize = PadHiddenSize_;
static constexpr bool PadIntermediateSize = PadIntermediateSize_; static constexpr bool PadIntermediateSize = PadIntermediateSize_;
static constexpr bool PipeInterleave = PipeInterleave_;
}; };
// Note: this need to be a bit mask // Note: this need to be a bit mask
......
...@@ -23,10 +23,10 @@ ...@@ -23,10 +23,10 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
......
...@@ -19,7 +19,8 @@ struct SmoothquantHostArgs ...@@ -19,7 +19,8 @@ struct SmoothquantHostArgs
index_t m; index_t m;
index_t n; index_t n;
index_t stride; // row_stride index_t x_stride; // input row_stride
index_t y_stride; // output row_stride
}; };
// TODO: Extract some type to wrapper class // TODO: Extract some type to wrapper class
...@@ -58,14 +59,21 @@ struct Smoothquant ...@@ -58,14 +59,21 @@ struct Smoothquant
index_t m; index_t m;
index_t n; index_t n;
index_t stride; // row_stride index_t x_stride; // input row_stride
index_t y_stride; // out row_stride
}; };
using Hargs = SmoothquantHostArgs; using Hargs = SmoothquantHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{ {
return Kargs{ return Kargs{hargs.p_x,
hargs.p_x, hargs.p_xscale, hargs.p_yscale, hargs.p_qy, hargs.m, hargs.n, hargs.stride}; hargs.p_xscale,
hargs.p_yscale,
hargs.p_qy,
hargs.m,
hargs.n,
hargs.x_stride,
hargs.y_stride};
} }
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
...@@ -116,7 +124,7 @@ struct Smoothquant ...@@ -116,7 +124,7 @@ struct Smoothquant
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x), static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1), make_tuple(kargs.x_stride, 1),
number<Vector_N>{}, number<Vector_N>{},
number<1>{}); number<1>{});
...@@ -157,7 +165,7 @@ struct Smoothquant ...@@ -157,7 +165,7 @@ struct Smoothquant
auto tmp_ = make_naive_tensor_view<address_space_enum::global>( auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<QYDataType*>(kargs.p_qy), static_cast<QYDataType*>(kargs.p_qy),
make_tuple(kargs.m, kargs.n), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1), make_tuple(kargs.y_stride, 1),
number<Vector_N>{}, number<Vector_N>{},
number<1>{}); number<1>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment