Unverified commit 06701e70, authored by Rostyslav Geyyer, committed by GitHub

Merge branch 'develop' into lwpck-1815

parents 5800d24e da42a889
@@ -36,30 +36,37 @@ template <typename T,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
-          bool oob_conditional_check = true>
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
 CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                   const tile_window_with_static_distribution<BottomTensorView_,
                                                                              WindowLengths_,
                                                                              TileDistribution_,
                                                                              NumCoord>& tile_window,
-                                  bool_constant<oob_conditional_check> = {})
+                                  bool_constant<oob_conditional_check> = {},
+                                  bool_constant<pre_nop> = {})
 {
-    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
+    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }
 
 template <typename LdsTileWindow_,
           typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
-          index_t NumCoord>
+          index_t NumCoord,
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
 CK_TILE_DEVICE auto
 async_load_tile_raw(LdsTileWindow_&& lds_tile,
                     const tile_window_with_static_distribution<BottomTensorView_,
                                                                WindowLengths_,
                                                                TileDistribution_,
-                                                               NumCoord>& tile_window)
+                                                               NumCoord>& tile_window,
+                    bool_constant<oob_conditional_check> = {},
+                    bool_constant<pre_nop> = {})
 {
-    return tile_window.async_load(lds_tile);
+    return tile_window.async_load_raw(
+        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }
 
 CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
...
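The new pre_nop template flag threads through each raw-load entry point as a trailing bool_constant tag and defaults to the old behavior. A minimal call-site sketch (dst_tile and dram_window are hypothetical placeholders, not names from this commit):

    // Hedged sketch: opting into the pre-nop padding on a raw tile load.
    load_tile_raw(dst_tile,
                  dram_window,
                  bool_constant<true>{},  // oob_conditional_check: keep bounds checks
                  bool_constant<true>{}); // pre_nop: pad the first access

    load_tile_raw(dst_tile, dram_window); // defaults reproduce the previous behavior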
@@ -35,6 +35,8 @@ struct null_tile_window
     CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
 
+    CK_TILE_DEVICE void init_raw() {}
+
     WindowLengths window_lengths_;
 };
...
@@ -36,6 +36,8 @@ struct tensor_view
     {
     }
 
+    CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }
+
     CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }
 
     CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
@@ -85,30 +87,34 @@ struct tensor_view
     // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
     template <typename X,
               bool oob_conditional_check = true,
+              bool pre_nop = false,
               typename std::enable_if<
                   std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                  typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                   bool>::type = false>
-    CK_TILE_HOST_DEVICE void
-    get_vectorized_elements_raw(remove_cvref_t<X>& dst,
-                                const TensorCoord& coord,
-                                bool_constant<oob_conditional_check> = {}) const
+    CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
+                                                         const TensorCoord& coord,
+                                                         bool_constant<oob_conditional_check> = {},
+                                                         bool_constant<pre_nop> = {}) const
     {
-        return buf_.template get_raw<X, oob_conditional_check>(
+        return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
             dst,
             coord.get_offset(),
-            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
+            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
+            bool_constant<pre_nop>{});
     }
 
     template <typename X,
+              bool pre_nop = false,
               typename std::enable_if<
                   std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                  typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                   bool>::type = false>
-    CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
-                                                                     const TensorCoord& coord) const
+    CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
+        remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
     {
-        return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
+        return buf_.template async_get_raw<X>(
+            smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
     }
 
     // X is vector of DataType.
...
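Passing pre_nop as a bool_constant tag instead of a runtime bool means each flag value is a distinct type, so the branch is resolved during template instantiation and never reaches the generated code. A standalone sketch of the idiom (illustrative names only, not library code):

    #include <type_traits>

    template <bool B>
    using bool_constant = std::integral_constant<bool, B>; // same shape as ck_tile's tag

    template <bool pre_nop = false>
    void get_raw_sketch(bool_constant<pre_nop> = {})
    {
        if constexpr(pre_nop)
        {
            // compiled only when the caller passes bool_constant<true>{}
        }
    }

    // usage: get_raw_sketch(bool_constant<true>{});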
@@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
 
 // TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
 // sub-dword tensor...
-template <typename DstrTensors, index_t v>
-CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
+template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
+CK_TILE_DEVICE void
+set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
 {
-    constexpr index_t tensor_bytes =
-        DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);
-    if constexpr(v == 0 && tensor_bytes % 4 == 0)
+    using elem_type             = typename DstrTensors::DataType;
+    constexpr index_t elem_size = sizeof(elem_type);
+
+    constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
+
+    // # bytes per write = 4
+    if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
     {
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+        auto& buffer = dstr_tensor.get_thread_buffer();
+
+        static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
+            if constexpr(elem_size == 1)
+            {
+                // # elements per write = 4
+                constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
+
+                buffer[i_write * 4 + 0] = values.x;
+                buffer[i_write * 4 + 1] = values.y;
+                buffer[i_write * 4 + 2] = values.z;
+                buffer[i_write * 4 + 3] = values.w;
+            }
+            else if constexpr(elem_size == 2)
+            {
+                // # elements per write = 2
+                constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
+
+                buffer[i_write * 2 + 0] = values.x;
+                buffer[i_write * 2 + 1] = values.y;
+            }
+            else if constexpr(elem_size == 4)
+            {
+                // # elements per write = 1
+                constexpr elem_type value = 0;
+
+                buffer[i_write] = value;
+            }
+            else
+            {
+                static_assert(false, "type not supported");
+            }
+        });
+#else
         using dvec_t = array<index_t, tensor_bytes / 4>;
         auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
         for(auto i = 0; i < tensor.size(); i++)
             tensor.get(i) = v;
+#endif
     }
     else
    {
-        tile_elementwise_inout(
-            [](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
-            dstr_tensor);
+        tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
+                               dstr_tensor);
    }
 }
...
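The fast path above writes the thread buffer one dword (4 bytes) at a time, which is the per-dword fill the TODO asks for. A host-side sketch of the same size arithmetic (plain C++ with a hypothetical buffer; the device code uses static_for over the thread buffer instead of a runtime loop):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    void zero_fill(void* buf, std::size_t bytes)
    {
        if(bytes % 4 == 0) // # bytes per write = 4
        {
            auto* p = static_cast<std::uint32_t*>(buf);
            for(std::size_t i = 0; i < bytes / 4; ++i)
                p[i] = 0u; // one dword store covers 4/2/1 elements of size 1/2/4 bytes
        }
        else
        {
            std::memset(buf, 0, bytes); // element-wise fallback
        }
    }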
@@ -344,9 +344,10 @@ struct tile_window_with_static_distribution
         return dst_tensor;
     }
 
-    template <typename DstTile, bool oob_conditional_check = true>
+    template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
     CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
-                                 bool_constant<oob_conditional_check> = {}) const
+                                 bool_constant<oob_conditional_check> = {},
+                                 bool_constant<pre_nop> = {}) const
     {
         using Traits = load_store_traits;
@@ -373,7 +374,13 @@ struct tile_window_with_static_distribution
             auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 
             static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+                constexpr auto pre_nop_ = [&]() {
+                    if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
+                        return bool_constant<true>{};
+                    else
+                        return bool_constant<false>{};
+                }();
 
                 // data index [y0, y1, ...]
                 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
@@ -384,7 +391,8 @@ struct tile_window_with_static_distribution
                 get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                     dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
                     bottom_tensor_thread_coord,
-                    bool_constant<oob_conditional_check>{});
+                    bool_constant<oob_conditional_check>{},
+                    pre_nop_);
 
                 // move thread coordinate
                 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
@@ -399,12 +407,17 @@ struct tile_window_with_static_distribution
                 }
             });
         });
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+        asm volatile("; this inline asm is workaround to prevent compiler from using too much "
+                     "scratch memory" ::);
+#endif
     }
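The immediately-invoked constexpr lambda turns a value-level condition into a type: pre_nop_ is bool_constant<true> only for the first access of the first coordinate bundle (iCoord == 0, iCoordAccess == 0), so only the leading load is nop-padded. A standalone reduction of the pattern (illustrative, not library code):

    #include <type_traits>

    template <bool pre_nop, int iCoord, int iCoordAccess>
    constexpr auto select_pre_nop()
    {
        // yields std::true_type or std::false_type, so callees can specialize on it
        if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
            return std::true_type{};
        else
            return std::false_type{};
    }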
 
     // TODO: currently async load only implemented in inline asm
-    template <typename LdsTileWindow_, bool oob_conditional_check = true>
-    CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
-                                   bool_constant<oob_conditional_check> = {}) const
+    template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
+    CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
+                                       bool_constant<oob_conditional_check> = {},
+                                       bool_constant<pre_nop> = {}) const
     {
         using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
         // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
@@ -449,11 +462,17 @@ struct tile_window_with_static_distribution
             auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 
             static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+                constexpr auto pre_nop_ = [&]() {
+                    if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
+                        return bool_constant<true>{};
+                    else
+                        return bool_constant<false>{};
+                }();
 
                 // read from bottom tensor
-                get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
-                    smem, bottom_tensor_thread_coord);
+                get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
+                    smem, bottom_tensor_thread_coord, pre_nop_);
 
                 // move thread coordinate
                 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
@@ -668,6 +687,67 @@ struct tile_window_with_static_distribution
         });
     }
 
+    CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
+    {
+        window_origin_ = new_window_origin;
+
+#if 0 // debug
+        // TODO: this use more register for FA, but less register for GEMM
+        // need investigation
+        // only support warp-tile and block-tile
+        static_assert(NDimP == 1 or NDimP == 2, "wrong!");
+
+        WindowAdaptorCoord window_adaptor_thread_coord_tmp;
+
+        if constexpr(NDimP == 1)
+        {
+            window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
+                tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
+        }
+        else if constexpr(NDimP == 2)
+        {
+            window_adaptor_thread_coord_tmp =
+                make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
+                                               AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
+        }
+#else
+        // TODO: this use less register for FA, but more register for GEMM
+        // need investigation
+        const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
+            tile_dstr_.get_ps_ys_to_xs_adaptor(),
+            container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
+#endif
+
+        BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
+            window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
+
+        const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
+            bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
+
+        // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
+        // future load/store() calls (might allocate more registers)
+        using Traits = load_store_traits;
+        using SFC_Ys = typename Traits::SFC_Ys;
+
+        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+            auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
+            auto bottom_tensor_thread_coord  = bottom_tensor_thread_coord_tmp;
+
+            constexpr auto idx_diff_ys =
+                SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
+
+            constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+
+            move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+
+            pre_computed_coords_(iCoord) =
+                make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
+        });
+    }
+
+    CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
+
     // this is the bottom tensor view
     // [x0', x1', ...] ==> [offset]
     BottomTensorView bottom_tensor_view_;
...
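set_window_origin re-anchors an existing window and refreshes the pre_computed_coords_ bundles in place, so a caller can move a persistent window instead of constructing a new one every iteration; the FMHA pipeline change below does exactly that:

    // From the pipeline diff below: advance the K block window, then re-anchor
    // the persistent K DRAM window rather than re-creating it with make_tile_window.
    move_tile_window(k_dram_block_window, {kN0, 0});
    k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());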
@@ -81,6 +81,12 @@ struct BlockFmhaPipelineQRKSVSAsync
             return Problem::kBlockPerCu;
         else
         {
+            // minimize occupancy
+            if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)
+            {
+                return 1;
+            }
+
             if constexpr(kK0BlockLength <= 32)
             {
                 if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
@@ -220,6 +226,7 @@ struct BlockFmhaPipelineQRKSVSAsync
             q_dram_block_window_tmp.get_window_lengths(),
             q_dram_block_window_tmp.get_window_origin(),
             Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
+        q_dram_window.init_raw();
 
         // TODO: we use async Copy for K, which is inline asm
         // a side effect is we have to use inline asm for q as well
@@ -293,6 +300,17 @@ struct BlockFmhaPipelineQRKSVSAsync
             k_dram_block_window.get_window_origin(),
             Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
                                                                     // load
+        k_dram_window.init_raw();
+
+        constexpr auto k_oob_ck = bool_constant<true>{};
+        constexpr auto k_pre_np = [&]() {
+            if constexpr(kPadSeqLenK &&
+                         (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                          (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)))
+                return bool_constant<true>{};
+            else
+                return bool_constant<false>{};
+        }();
 
         const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
         auto bias_dram_window  = make_tile_window(
             bias_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -310,7 +328,7 @@ struct BlockFmhaPipelineQRKSVSAsync
             Policy::template MakeVDramTileDistribution<Problem>());
 
         // prefetch K tile
-        async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
+        async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
         move_tile_window(k_dram_window, {0, kK0});
 
         __builtin_amdgcn_sched_barrier(0);
@@ -333,7 +351,9 @@ struct BlockFmhaPipelineQRKSVSAsync
         {
             static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
                 async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
-                                    k_dram_window);
+                                    k_dram_window,
+                                    k_oob_ck,
+                                    k_pre_np);
                 if constexpr(i_k0 < k0_loops - 1)
                     move_tile_window(k_dram_window, {0, kK0});
@@ -637,16 +657,13 @@ struct BlockFmhaPipelineQRKSVSAsync
             {
                 // move K tile windows
                 move_tile_window(k_dram_block_window, {kN0, 0});
-                k_dram_window =
-                    make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
-                                     k_dram_block_window.get_window_lengths(),
-                                     k_dram_block_window.get_window_origin(),
-                                     Policy::template MakeKDramTileDistribution<Problem>());
+                k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
 
                 if constexpr(k1_loops >= 2 &&
                              LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
                     __builtin_amdgcn_s_barrier();
-                async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
+                async_load_tile_raw(
+                    k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
                 move_tile_window(k_dram_window, {0, kK0});
             }
 
             // tail
...
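For context, the async-K prefetch rhythm these calls participate in, reduced to a hedged skeleton (slot is a placeholder; this assumes async_load_fence(cnt) waits, vmcnt-style, until at most cnt async loads remain outstanding, so the default 0 drains them all):

    async_load_tile_raw(k_lds_store(slot), k_dram_window, k_oob_ck, k_pre_np); // issue async copy into LDS
    move_tile_window(k_dram_window, {0, kK0}); // advance to the next K sub-tile
    async_load_fence();                        // drain outstanding async loads
    __builtin_amdgcn_s_barrier();              // publish the LDS tile to the whole block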
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -133,5 +133,40 @@ struct FillConstant
     }
 };
 
+template <typename T>
+struct TransformIntoStructuralSparsity
+{
+    // clang-format off
+    static constexpr T valid_sequences[] = {
+        0, 0, 1, 1,
+        0, 1, 0, 1,
+        0, 1, 1, 0,
+        1, 0, 0, 1,
+        1, 0, 1, 0,
+        1, 1, 0, 0,
+    };
+    // clang-format on
+
+    template <typename ForwardIter>
+    void operator()(ForwardIter first, ForwardIter last) const
+    {
+        std::for_each(first, last, [=, idx = 0](T& elem) mutable {
+            auto tmp_idx = idx;
+            idx += 1;
+            return elem *= valid_sequences[tmp_idx % (sizeof(valid_sequences) / sizeof(T))];
+        });
+    }
+
+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range) const
+        -> std::void_t<decltype(std::declval<const TransformIntoStructuralSparsity&>()(
+            std::begin(std::forward<ForwardRange>(range)),
+            std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
+};
+
 } // namespace utils
 } // namespace ck
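Each row of valid_sequences keeps exactly two of four elements, i.e. the six legal 2:4 structured-sparsity masks. A small host-side usage sketch (assumes this fill header and <vector> are included):

    std::vector<float> a(16, 1.0f);
    ck::utils::TransformIntoStructuralSparsity<float>{}(a); // range overload
    // a == {0,0,1,1, 0,1,0,1, 0,1,1,0, 1,0,0,1}: every 4-wide group has 2 non-zeros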
@@ -43,7 +43,15 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
             first = false;
         else
             os << delim;
-        os << static_cast<T>(v);
+
+        if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck::bf8_t>)
+        {
+            os << ck::type_convert<float>(v);
+        }
+        else
+        {
+            os << static_cast<T>(v);
+        }
     }
     return os;
 }
...
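The special case is needed because byte-backed types stream as characters rather than numbers; converting to float prints the value. An analogous standalone illustration with int8_t (ck::f8_t is likewise byte-sized):

    #include <cstdint>
    #include <iostream>

    int main()
    {
        std::int8_t v = 65;
        std::cout << v << '\n';                     // prints 'A': streamed as a character
        std::cout << static_cast<float>(v) << '\n'; // prints 65: the numeric value
    }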
@@ -59,7 +59,7 @@ function(add_instance_library INSTANCE_NAME)
     endforeach()
     # Do not build WMMA instances if gfx11 targets are not on the target list
     foreach(source IN LISTS ARGN)
-        if(NOT INST_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
+        if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
             message("removing wmma instance ${source} ")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
@@ -177,7 +177,7 @@ FOREACH(subdir_path ${dir_list})
         message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.")
         set(add_inst 0)
     endif()
-    if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11"))
+    if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12"))
         message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.")
         set(add_inst 0)
     endif()
@@ -185,11 +185,11 @@ FOREACH(subdir_path ${dir_list})
         message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.")
         set(add_inst 0)
     endif()
-    if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9"))
+    if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12") AND (NOT INST_TARGETS MATCHES "gfx9"))
         message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.")
         set(add_inst 0)
     endif()
-    if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
+    if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
         message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
         set(add_inst 0)
     endif()
...
# ONLY XDL_KERNELS
set(GEMM_UNIVERSAL_STREAMK_INSTANCES)
list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp)
add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances<GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances<GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
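For orientation, a hedged sketch of how such registration functions are consumed, with the vector element type taken from the signatures above (client-side code, not part of this commit; assumes the same Row/F16/PassThrough aliases as the instance headers):

    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row, Row, Row,
                                                      F16, F16, F16,
                                                      PassThrough, PassThrough, PassThrough>>>
        instances;
    ck::tensor_operation::device::instance::
        add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances(instances);
    // each collected instance can then be queried or benchmarked, e.g. by ckProfiler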