Commit aa30ef56 authored by Jakub Piasecki

resolved conflicts

parents 3cad16c4 5d671a5f
@@ -9,8 +9,6 @@
#include <string>
#include <tuple>
-#include "ck_tile/ops/epilogue.hpp"
-#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
...
@@ -52,7 +52,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// using WarpTile = ck_tile::sequence<1, 512>;
// using Vector = ck_tile::sequence<1, 8>;
-constexpr ck_tile::index_t kBlockSize = 512;
constexpr ck_tile::index_t kBlockSize = 256;
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
std::cout << "grid size " << kGridSize << std::endl;
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
@@ -1558,14 +1558,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
}
}
-if(!(arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 &&
-     arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0))
const bool is_w_pad_zero = arg.input_left_pads_[NDimSpatial - 1] == 0 &&
                           arg.input_right_pads_[NDimSpatial - 1] == 0;
const auto X = arg.filter_spatial_lengths_[NDimSpatial - 1];
const bool XC_access_allowed = arg.Conv_G_ == 1 &&
                               (arg.Conv_C_ * X) % BBlockTransferSrcScalarPerVector == 0 &&
                               is_w_pad_zero;
if(!((arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 || XC_access_allowed) &&
     arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0))
{
-if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1))
if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1 &&
     NumGroupsToMerge > 1))
{
return false;
}
-if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1))
if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1 &&
     NumGroupsToMerge > 1))
{
return false;
}
...
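The XC_access_allowed branch above is the point of this hunk: with a single group and no padding on the last spatial dimension, the C and X axes of the NHWGC/GKYXC tensors are contiguous in memory, so they can be merged and vectorized together even when C alone is not divisible by the vector width. A standalone sketch of the predicate (illustrative names, not CK API):

// Vectorized B-tensor access is legal if C alone is divisible by the vector
// width, or if the merged C*X axis is divisible and truly contiguous
// (single group, zero left/right padding on W).
bool b_vector_access_ok(int G, int C, int X, int left_pad_w, int right_pad_w, int vec)
{
    const bool is_w_pad_zero     = left_pad_w == 0 && right_pad_w == 0;
    const bool xc_access_allowed = G == 1 && (C * X) % vec == 0 && is_w_pad_zero;
    return C % vec == 0 || xc_access_allowed;
}

For example C = 3, X = 4, vec = 4: C % 4 != 0, but (3 * 4) % 4 == 0, so an unpadded single-group problem can still use 4-wide loads.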
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
@@ -584,6 +584,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
{
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t>)
{
return false;
}
if constexpr(NDimSpatial == 1)
{
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
...
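The guard added above rejects bf16 instances on devices without hardware bf16 atomics: this backward-weight kernel can have several workgroups accumulate partial dW results into the same output tile, which requires an atomic add on CDataType. The accumulation pattern in miniature (illustrative HIP with float in place of bhalf_t; is_bf16_atomic_supported() is the CK device query used in the diff):

#include <hip/hip_runtime.h>

// Each workgroup adds its partial weight gradient into the shared buffer.
// With CDataType = bhalf_t this atomicAdd must exist in hardware, hence the
// early-out in IsSupportedArgument.
__global__ void accumulate_partial(float* dw, const float* partial, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        atomicAdd(&dw[i], partial[i]);
}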
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
@@ -53,7 +53,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I10 = Number<10>{};
static constexpr auto I12 = Number<12>{};
static constexpr auto I13 = Number<13>{};
static constexpr auto I14 = Number<14>{};
static constexpr auto I16 = Number<16>{};
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
@@ -198,9 +211,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_oob_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
-using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
-using src_vector_t    = typename src_vector_type::type;
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;

dst_vector_type op_r_v;
@@ -234,14 +244,63 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
-auto src_vector_container = src_vector_type{
-    src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
-static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
-    // apply the src elementwise op and convert to DstData under the hood if needed
-    src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
-                    src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
-});
using VectorSizeLookupTable = Tuple<Sequence<>,
Sequence<I1>,
Sequence<I2>,
Sequence<I2, I1>,
Sequence<I4>,
Sequence<I4, I1>,
Sequence<I4, I2>,
Sequence<I4, I2, I1>,
Sequence<I8>,
Sequence<I8, I1>,
Sequence<I8, I2>,
Sequence<I8, I2, I1>,
Sequence<I8, I4>,
Sequence<I8, I4, I1>,
Sequence<I8, I4, I2>,
Sequence<I8, I4, I2, I1>,
Sequence<I16>>;
using VectorOffsetsLookupTable = Tuple<Sequence<>,
Sequence<I0>,
Sequence<I0>,
Sequence<I0, I2>,
Sequence<I0>,
Sequence<I0, I4>,
Sequence<I0, I4>,
Sequence<I0, I4, I6>,
Sequence<I0>,
Sequence<I0, I8>,
Sequence<I0, I8>,
Sequence<I0, I8, I10>,
Sequence<I0, I8>,
Sequence<I0, I8, I12>,
Sequence<I0, I8, I12>,
Sequence<I0, I8, I12, I14>,
Sequence<I0>>;
static_for<0, tuple_element_t<SrcScalarPerVector, VectorSizeLookupTable>::Size(), 1>{}(
[&](auto v_idx) {
constexpr auto VectorLoadSize =
tuple_element_t<SrcScalarPerVector, VectorSizeLookupTable>::At(v_idx);
constexpr auto LoadOffset =
tuple_element_t<SrcScalarPerVector, VectorOffsetsLookupTable>::At(v_idx);
using src_vector_container = vector_type_maker_t<SrcData, VectorLoadSize>;
using src_vector_container_t = typename src_vector_container::type;
src_vector_container src_vector =
src_vector_container{src_buf.template Get<src_vector_container_t>(
src_coord_.GetOffset() / PackedSize + LoadOffset, true)};
static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if
// needed
src_element_op_(
op_r_v.template AsType<dst_elem_op_vec_t>()(idx + LoadOffset),
src_vector.template AsType<src_elem_op_vec_t>()[idx]);
});
});
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id)
...
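The two lookup tables drive the new load loop: entry N of VectorSizeLookupTable lists the power-of-two chunk widths summing to N, and the matching entry of VectorOffsetsLookupTable gives each chunk's element offset, so an irregular SrcScalarPerVector such as 13 becomes loads of 8, 4 and 1 elements at offsets 0, 8 and 12. The same greedy decomposition as a tiny standalone sketch (host-side illustration, not CK code):

#include <cstdio>

int main()
{
    const int src_scalar_per_vector = 13; // any value 1..16 covered by the tables
    int offset = 0;
    for(int chunk = 16; chunk >= 1; chunk /= 2)
    {
        if(src_scalar_per_vector & chunk) // set bits give the chunk widths
        {
            std::printf("load %2d elements at offset %2d\n", chunk, offset);
            offset += chunk;
        }
    }
    // 13 -> 8 @ 0, 4 @ 8, 1 @ 12, i.e. Sequence<I8, I4, I1> / Sequence<I0, I8, I12>
}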
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
@@ -314,6 +314,76 @@ struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
}
};
template <typename T>
struct vector_type<T, 3, typename std::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d3_t __attribute__((ext_vector_type(3)));
using type = d3_t;
union
{
d3_t d3_;
StaticallyIndexedArray<d1_t, 3> d1x3_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
StaticallyIndexedArray<d3_t, 1> d3x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d3_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x3_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
else if constexpr(is_same<X, d3_t>::value)
{
return data_.d3x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d3_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x3_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
else if constexpr(is_same<X, d3_t>::value)
{
return data_.d3x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
{
@@ -384,6 +454,158 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
}
};
template <typename T>
struct vector_type<T, 5, typename std::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d5_t __attribute__((ext_vector_type(5)));
using type = d5_t;
union
{
d5_t d5_;
StaticallyIndexedArray<d1_t, 5> d1x5_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
StaticallyIndexedArray<d5_t, 1> d5x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value || is_same<X, d5_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x5_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else if constexpr(is_same<X, d5_t>::value)
{
return data_.d5x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value || is_same<X, d5_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x5_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else if constexpr(is_same<X, d5_t>::value)
{
return data_.d5x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 7, typename std::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d7_t __attribute__((ext_vector_type(7)));
using type = d7_t;
union
{
d7_t d7_;
StaticallyIndexedArray<d1_t, 7> d1x7_;
StaticallyIndexedArray<d2_t, 3> d2x3_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
StaticallyIndexedArray<d7_t, 1> d7x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d7_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x7_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x3_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else if constexpr(is_same<X, d7_t>::value)
{
return data_.d7x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d7_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x7_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x3_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else if constexpr(is_same<X, d7_t>::value)
{
return data_.d7x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
{
@@ -466,6 +688,88 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
}
};
template <typename T>
struct vector_type<T, 13, typename std::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d13_t __attribute__((ext_vector_type(13)));
using type = d13_t;
union
{
d13_t d13_;
StaticallyIndexedArray<d1_t, 13> d1x13_;
StaticallyIndexedArray<d4_t, 3> d4x3_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
StaticallyIndexedArray<d13_t, 1> d13x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d13_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x13_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x3_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
else if constexpr(is_same<X, d13_t>::value)
{
return data_.d13x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d13_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x13_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x3_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
else if constexpr(is_same<X, d13_t>::value)
{
return data_.d13x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
{
...
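The odd-width specializations added here (3, 5, 7, 13) mirror the existing power-of-two ones: one clang ext_vector_type for storage plus a union of StaticallyIndexedArray views, so the chunked loads above can write through the widest legal sub-vector view. A hedged usage sketch for the 13-wide case (access style assumed from how CK indexes these containers elsewhere in this diff):

using v13 = ck::vector_type<float, 13>;
v13 v;                                              // zero-initialized via type{0}
v.AsType<float>()(ck::Number<12>{}) = 1.f;          // scalar lane 12 via d1x13_
auto quad = v.AsType<v13::d4_t>()[ck::Number<2>{}]; // lanes 8..11 via d4x3_, matching the 8+4+1 split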
@@ -106,11 +106,6 @@ struct BlockFmhaPipelineQSKSVS
return Policy::template GetSmemSize<Problem>();
}
-CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
-{
-    return Policy::template GetSmemSizeQ<Problem>();
-}
template <typename QDramBlockWindowTmp,
          typename KDramBlockWindowTmp,
          typename VDramBlockWindowTmp,
@@ -328,8 +323,7 @@ struct BlockFmhaPipelineQSKSVS
});
}
-const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
block_sync_lds();
gemm_0(s_acc, q_lds_window, k_lds_window);
block_sync_lds();
@@ -341,6 +335,10 @@ struct BlockFmhaPipelineQSKSVS
gemm_0(s_acc, q_lds_window, k_lds_window);
}
__builtin_amdgcn_sched_barrier(0);
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
__builtin_amdgcn_sched_barrier(0);
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -462,6 +460,12 @@ struct BlockFmhaPipelineQSKSVS
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
__builtin_amdgcn_sched_barrier(0);
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
@@ -509,9 +513,6 @@ struct BlockFmhaPipelineQSKSVS
}
move_tile_window(v_dram_window, {0, kK1});
-const auto p =
-    cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
...
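The reshuffled tail above leans on __builtin_amdgcn_sched_barrier(0): with mask 0 the AMDGPU backend may not move any instruction across the barrier, so the V prefetch is issued exactly between the tail gemm_0 and the softmax stage, and the p cast is likewise pinned before STAGE 3 where its registers are consumed. The pattern in miniature (mask semantics per the LLVM AMDGPU builtin; whether the placement pays off is workload-dependent):

__builtin_amdgcn_sched_barrier(0); // nothing may be scheduled across this point
const auto v_prefetch = load_tile(v_dram_window); // global load starts exactly here
__builtin_amdgcn_sched_barrier(0); // ...and later stages cannot float above it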
@@ -9,11 +9,33 @@
namespace ck_tile {

// This pipeline is qkv all located in LDS
-using BlockFmhaPipelineQSKSVSDefaultPolicy =
-    BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
-                                        /* AsyncCopyK = */ false,
-                                        /* AsyncCopyV = */ false,
-                                        /* NumPrefetchK = */ 1,
-                                        /* NumPrefetchV = */ 1>;
struct BlockFmhaPipelineQSKSVSDefaultPolicy
    : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
                                          /* AsyncCopyK = */ false,
                                          /* AsyncCopyV = */ false,
                                          /* NumPrefetchK = */ 1,
                                          /* NumPrefetchV = */ 1>
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
{
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::KDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
{
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::VDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return max(GetSmemSizeQ<Problem>() + GetSmemSizeK<Problem>(), GetSmemSizeV<Problem>()) +
GetSmemSizeDropout<Problem>();
}
};
} // namespace ck_tile
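The overridden GetSmemSize() encodes the LDS lifetimes of this pipeline: Q and K must coexist for gemm_0, while V only becomes live afterwards and can reuse the same space, plus whatever the dropout path needs. Worked numbers under assumed fp16 tile sizes (illustrative only, not from the diff):

// Q 128x32, K 128x32, V 32x128, all fp16 (2 bytes), no dropout scratch:
constexpr int smem_q = 128 * 32 * 2; // 8 KiB
constexpr int smem_k = 128 * 32 * 2; // 8 KiB
constexpr int smem_v = 32 * 128 * 2; // 8 KiB
constexpr int smem   = (smem_q + smem_k > smem_v ? smem_q + smem_k : smem_v) + 0;
static_assert(smem == 16384); // max(8 KiB + 8 KiB, 8 KiB) = 16 KiB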
@@ -146,8 +146,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
-using QDataType = remove_cvref_t<typename Problem::QDataType>;
-return 16 / sizeof(QDataType);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
// this should align with MakeQDramTileDistribution()
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
return min(ElemPerThread, MaxVectorSize);
}
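GetAlignmentQ() now caps the vector width by how many Q elements a thread actually owns instead of always promising a 16-byte load. Worked example with assumed shapes (not from the diff):

// fp16 Q, kBlockSize = 256, kM0 = 64, kK0 = 32:
constexpr int max_vector_size = 16 / 2;          // 8 elements per 16-byte load
constexpr int elem_per_thread = (64 * 32) / 256; // 8 elements owned per thread
constexpr int alignment =
    elem_per_thread < max_vector_size ? elem_per_thread : max_vector_size;
static_assert(alignment == 8);
// shrinking the tile to kM0 = 32 gives elem_per_thread = 4, so the alignment
// drops to 4 instead of requesting an 8-wide load no thread could fill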
template <typename Problem>
@@ -156,19 +164,25 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
using QDataType = remove_cvref_t<typename Problem::QDataType>;

constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
-constexpr index_t K1 = 16 / sizeof(QDataType); // use dwordx4. TODO: change this
-constexpr index_t K0 = kKPerBlock / K1;
-constexpr index_t M2 = get_warp_size() / K0;
-constexpr index_t M1 = kBlockSize / get_warp_size();
-constexpr index_t M0 = kMPerBlock / (M2 * M1);
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);

constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
-tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
      sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
@@ -215,18 +229,31 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
             std::is_same_v<typename Problem::KDataType, half_t> &&
             std::is_same_v<typename Problem::SaccDataType, float>)
{
-return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
                  std::is_same_v<typename Problem::KDataType, bf16_t> &&
                  std::is_same_v<typename Problem::SaccDataType, float>)
{
-return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
                  std::is_same_v<typename Problem::KDataType, fp8_t> &&
...
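MakeQDramTileDistribution() now derives the whole thread layout from that same capped vector length instead of hard-coding dwordx4, and the WarpGemmM dispatch picks the 32x32, 16x16 or 4x64 MFMA variant from the leading dimension of Gemm0WarpTile. A worked trace of the distribution under assumed numbers (not from the diff):

// fp16, kBlockSize = 256, kM0 = 128, kK0 = 32, warp size 64:
constexpr int k_per_thread      = 8;                 // min(128 * 32 / 256 = 16, 8)
constexpr int k_threads         = 32 / k_per_thread; // 4 threads cover the K axis
constexpr int m_thread_per_warp = 64 / k_threads;    // 16
constexpr int num_warps         = 256 / 64;          // 4
constexpr int m_per_thread      = 128 / (m_thread_per_warp * num_warps); // 2
static_assert(k_threads * m_thread_per_warp * num_warps == 256); // every thread used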
@@ -21,35 +21,20 @@ struct BlockGemmARegBRegCRegV1
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
-// C += A * B
-template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
-CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
-                               const ABlockTensor& a_block_tensor,
-                               const BBlockTensor& b_block_tensor) const
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
-static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
-              std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
-              std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
-              "wrong!");
-constexpr index_t MPerBlock = BlockGemmShape::kM;
-constexpr index_t NPerBlock = BlockGemmShape::kN;
-constexpr index_t KPerBlock = BlockGemmShape::kK;
-constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
-using WG = remove_cvref_t<decltype(config.template at<0>())>;
-constexpr index_t MWarp = config.template at<1>();
-constexpr index_t NWarp = config.template at<2>();
-constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
-constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
-constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
-// M->N Warp
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
@@ -57,7 +42,14 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
@@ -65,7 +57,14 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
@@ -73,15 +72,28 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
return c_block_dstr_encode;
}
-constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-    a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
-constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-    b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
-constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-    c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// C += A * B
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                               const ABlockTensor& a_block_tensor,
                               const BBlockTensor& b_block_tensor) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
              std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
              std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
              "wrong!");

constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode();
constexpr auto b_block_dstr_encode = MakeBBlockDistributionEncode();
constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode();
// check ABC-block-distribution
static_assert(
@@ -159,20 +171,6 @@ struct BlockGemmARegBRegCRegV1
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
-constexpr index_t MPerBlock = BlockGemmShape::kM;
-constexpr index_t NPerBlock = BlockGemmShape::kN;
-constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
-using WG = remove_cvref_t<decltype(config.template at<0>())>;
-constexpr index_t MWarp = config.template at<1>();
-constexpr index_t NWarp = config.template at<2>();
-constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
-constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
-// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
...
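Hoisting the shape constants to struct scope lets the three Make*BlockDistributionEncode() helpers, operator() and MakeCBlockTile() share one set of definitions instead of each recomputing them. The iteration counts they encode are plain ratios; with assumed numbers:

// BlockGemmShape 128x128x32, a 2x2 warp grid, 32x32x16 warp GEMM:
constexpr int m_iter_per_warp = 128 / (2 * 32); // 2 M-tiles per warp
constexpr int n_iter_per_warp = 128 / (2 * 32); // 2 N-tiles per warp
constexpr int k_iter_per_warp = 32 / 16;        // 2 K-steps per warp
static_assert(m_iter_per_warp * n_iter_per_warp * k_iter_per_warp == 8); // warp GEMMs per block iteration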
@@ -104,9 +104,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;

static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
...
@@ -23,6 +23,8 @@ struct GemmPipelineAGmemBGmemCRegV1
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
@@ -126,7 +128,7 @@ struct GemmPipelineAGmemBGmemCRegV1
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

// Block GEMM
-auto block_gemm = Policy::template GetBlockGemm<Problem>();
auto block_gemm = BlockGemm();

// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
...
@@ -12,8 +12,11 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};

-static constexpr bool TransposeC = false;
static constexpr bool TransposeC = true;

#if 0
// 2d
@@ -491,10 +494,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
-constexpr auto I0 = number<0>{};
-constexpr auto I1 = number<1>{};
-constexpr auto I2 = number<2>{};
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
...
@@ -11,7 +11,6 @@ namespace ck_tile {
// UniversalGemm Policy
struct UniversalGemmPipelineAgBgCrPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -75,6 +75,28 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances
// clang-format on
>;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec,
BlockGemmPipelineScheduler Scheduler,
BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_irregular_instances =
std::tuple<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 48, 64, 32, 8, 16, 16, 3, 4, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 3, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 48, 32, 8, 16, 16, 4, 3, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 3, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 4, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 4, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 208, 32, 8, 16, 16, 4, 13, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 13, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>
// clang-format on
>;
template <ck::index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
@@ -118,6 +140,28 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instance
// clang-format on
>;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec,
BlockGemmPipelineScheduler Scheduler,
BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_irregular_instances =
std::tuple<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 48, 64, 32, 8, 16, 16, 3, 4, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 3, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 48, 32, 8, 16, 16, 4, 3, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 3, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 4, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 4, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 208, 32, 8, 16, 16, 4, 13, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 13, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>
// clang-format on
>;
template <ck::index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
@@ -358,6 +358,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -383,6 +387,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instances(
op_ptrs);
}
#endif
}
@@ -478,6 +486,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -503,6 +515,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instances(
op_ptrs);
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
...
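These registrations feed CK's instance-factory lookup, so client code never names the new irregular instances directly; it simply sees more candidates when enumerating. The usual pattern, sketched (the include path and aliases follow CK's client examples; treat them as assumptions):

#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"

using F16         = ck::half_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
namespace conv    = ck::tensor_layout::convolution;

using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
    2, conv::NHWGC, conv::GKYXC, conv::NHWGK, F16, F16, F16,
    PassThrough, PassThrough, PassThrough>;

// now also contains the pipev2/pipev5 "irregular" tiles (NPerBlock 48/80/112/208)
const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();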
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
@@ -149,6 +149,30 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
@@ -234,6 +258,30 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
@@ -384,6 +432,30 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
@@ -469,6 +541,30 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
...
@@ -19,6 +19,10 @@ set(GROUPED_CONV2D_BWD_WEIGHT
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instance.cpp
)
if(DL_KERNELS)
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_irregular_instances<
2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v2>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_irregular_instances<
2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v5>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck