Commit ef5e60f6 authored by illsilin

Merge branch 'develop' into gfx950

parents 2cc0fa26 5e93fa9e
...@@ -292,12 +292,15 @@ struct tile_window_with_static_distribution
{ {
constexpr auto tile_dstr = TileDstr{}; constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr); auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
load(dst_tensor, bool_constant<oob_conditional_check>{}); load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
return dst_tensor; return dst_tensor;
} }
template <typename DistributedTensor, bool oob_conditional_check = true> template <typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
using Traits = load_store_traits; using Traits = load_store_traits;
...@@ -785,6 +788,73 @@ struct tile_window_with_static_distribution
}); });
} }
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
// move thread's bottom tensor coordinate // move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset] // [x0', x1', ... ] ==> [offset]
// also move window-origin // also move window-origin
......
...@@ -432,23 +432,38 @@ struct tile_window_linear
CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>) CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>)
{ {
constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{}); constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
// since this is linear offset, we assum bottom X tensor is always linear constexpr auto is_pure_linear_tensor =
constexpr index_t linear_offset = [&]() { reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{});
constexpr auto x_idx_ = linear_coord; if constexpr(is_pure_linear_tensor)
constexpr auto x_len_ = TileDstr{}.get_lengths(); {
static_assert(x_idx_.size() == x_len_.size()); // this case usually is an LDS window, everything is known at compile time.
constexpr index_t x_dims_ = x_idx_.size(); // we directly use BottomTensorView transform to compute the offset, in case padding
index_t cu_stride_ = 1; auto bottom_tensor_coord =
index_t cu_offset_ = 0; make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
static_for<0, x_dims_, 1>{}([&](auto i_) { return bottom_tensor_coord.get_offset();
auto r_i_ = number<x_dims_ - i_ - 1>{}; }
cu_offset_ += x_idx_[r_i_] * cu_stride_; else
cu_stride_ *= x_len_[r_i_]; {
}); // this case usually is a global window, where last dim can be linear
return cu_offset_; // we hack here and use the original TileDstr to compute the linear offset
}(); // ... hoping that there is no extra padding between other dims, which makes sense
// since that would introduce runtime length (so can't use linear offset)
return linear_offset; constexpr index_t linear_offset = [&]() {
constexpr auto x_idx_ = linear_coord;
constexpr auto x_len_ = TileDstr{}.get_lengths();
static_assert(x_idx_.size() == x_len_.size());
constexpr index_t x_dims_ = x_idx_.size();
index_t cu_stride_ = 1;
index_t cu_offset_ = 0;
static_for<0, x_dims_, 1>{}([&](auto i_) {
auto r_i_ = number<x_dims_ - i_ - 1>{};
cu_offset_ += x_idx_[r_i_] * cu_stride_;
cu_stride_ *= x_len_[r_i_];
});
return cu_offset_;
}();
return linear_offset;
}
} }
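// Worked example for the fallback branch above (illustrative only, not part of this diff):
// with TileDstr lengths [4, 8] and linear_coord [2, 3], the fold walks the dims from last
// to first:
//   i=0 (dim 1): cu_offset = 3 * 1 = 3,  cu_stride = 8
//   i=1 (dim 0): cu_offset = 3 + 2 * 8 = 19, cu_stride = 32
// i.e. the plain row-major offset 2 * 8 + 3 = 19, which is exactly why the comment assumes
// no extra (runtime) padding between the non-last dims.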
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; } CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
...@@ -509,6 +524,64 @@ struct tile_window_linear
return dst_tensor; return dst_tensor;
} }
template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DstTile& dst_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
#if 1
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
};
WINDOW_DISPATCH_ISSUE();
return dst_tensor;
}
template <typename DstTile, template <typename DstTile,
index_t i_access = -1, index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
...@@ -849,6 +922,58 @@ struct tile_window_linear
WINDOW_DISPATCH_ISSUE(); WINDOW_DISPATCH_ISSUE();
} }
template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
auto bottom_tensor_flag = cached_flags_[IAccess];
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
vec_value,
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
};
WINDOW_DISPATCH_ISSUE();
}
// move thread's bottom tensor coordinate // move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset] // [x0', x1', ... ] ==> [offset]
// also move window-origin // also move window-origin
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#pragma once
namespace ck_tile {
// input an LDS store tile and extract some information from it
// used to set the m0 value for the gfx9 series
template <typename LdsTileWindow_>
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile)
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType);
const index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
return make_tuple(m0_init_value, size_per_issue);
}
} // namespace ck_tile
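// Numeric sketch of the values returned above (an assumed example, not from this commit):
// for a packed (issues=2, warps=4, lanes=64) LDS window of fp16 (2 bytes per element):
//   size_per_buf   = offset(0,0,0) * 2 = 0
//   size_per_wave  = offset(0,1,0) * 2 - size_per_buf = 64  * 2 = 128 bytes
//   size_per_issue = offset(1,0,0) * 2 - size_per_buf = 256 * 2 = 512 bytes
//   m0_init_value  = size_per_buf + size_per_wave * get_warp_id()
// so wave w programs m0 to w * 128 and bumps it by 512 bytes after each async issue.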
...@@ -41,15 +41,65 @@ template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
typename DataType_> typename DataType_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE void CK_TILE_DEVICE void
update_tile(tile_window_with_static_distribution<BottomTensorView_, update_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{ {
tile_window.update(dstr_tensor); tile_window.update(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void
update_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.update_raw(dstr_tensor,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename DataType_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto update_tile_raw(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.update_raw(dstr_tensor,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
} }
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
namespace ck_tile {
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
{
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
template <typename T>
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
{
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
template <typename Context, index_t Start = 0, index_t Step = 1>
struct static_counter
{
public:
template <typename Unique>
static constexpr index_t next()
{
return next<Unique>(0) * Step + Start;
}
template <unsigned long long>
static constexpr index_t next()
{
struct Unique
{
};
return next<Unique>(0) * Step + Start;
}
template <typename Unique>
static constexpr index_t current()
{
return current<Unique>(0) * Step + Start;
}
template <unsigned long long>
static constexpr index_t current()
{
struct Unique
{
};
return current<Unique>(0) * Step + Start;
}
private:
template <index_t I>
struct slot
{
_Pragma("GCC diagnostic push");
_Pragma("GCC diagnostic ignored \"-Wundefined-internal\"");
friend constexpr bool slot_allocated(slot<I>);
_Pragma("GCC diagnostic pop");
};
template <index_t I>
struct allocate_slot
{
friend constexpr bool slot_allocated(slot<I>) { return true; }
enum
{
value = I
};
};
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
// the overload set...
template <typename Unique, index_t I = 0, bool = slot_allocated(slot<I>())>
static constexpr index_t next(index_t)
{
return next<Unique, I + 1>(0);
}
// ...And this function will be used, instead, which will define slot_allocated(slot<I>) via
// allocate_slot<I>.
template <typename Unique, index_t I = 0>
static constexpr index_t next(double)
{
return allocate_slot<I>::value;
}
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
// the overload set...
template <typename Unique, index_t I = Start, bool = slot_allocated(slot<I>())>
static constexpr index_t current(index_t)
{
return current<Unique, I + 1>(0);
}
// ...And this function will be used, instead, which will return the current counter, or assert
// in case next() hasn't been called yet.
template <typename Unique, index_t I = Start>
static constexpr index_t current(double)
{
static_assert(I != 0, "You must invoke next() first");
return I - 1;
}
};
namespace impl {
template <int I>
struct static_counter_uniq_;
}
#define MAKE_SC() \
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
#define MAKE_SC_WITH(start_, step_) \
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> {}
#define NEXT_SC(c_) c_.next<__COUNTER__>()
#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>()
// Usage:
// constexpr auto c = MAKE_SC()
// NEXT_SC(c) // -> constexpr 0
// NEXT_SC(c) // -> constexpr 1
// NEXT_SC(c) // -> constexpr 2
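// A further illustrative sketch (assumed, mirroring the usage above) for the start/step variant:
// constexpr auto c2 = MAKE_SC_WITH(4, 2)
// NEXT_SC(c2) // -> constexpr 4
// NEXT_SC(c2) // -> constexpr 6
// NEXT_SCI(c_, i_) additionally offsets __COUNTER__ by a static index, which helps when the
// call site is stamped out inside a static_for and each iteration must mint its own slot.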
} // namespace ck_tile
...@@ -11,6 +11,7 @@
#include "ck_tile/host/fill.hpp" #include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp" #include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/ranges.hpp" #include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp" #include "ck_tile/host/reference/reference_batched_dropout.hpp"
...@@ -20,6 +21,7 @@
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
......
...@@ -7,6 +7,7 @@
#include <stdint.h> #include <stdint.h>
#include <stdexcept> #include <stdexcept>
#include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile { namespace ck_tile {
template <typename T> template <typename T>
...@@ -36,6 +37,19 @@ struct DeviceMem
mpDeviceBuf = nullptr; mpDeviceBuf = nullptr;
} }
} }
template <typename T>
DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
{
if(mMemSize != 0)
{
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
else
{
mpDeviceBuf = nullptr;
}
ToDevice(t.data());
}
void Realloc(std::size_t mem_size) void Realloc(std::size_t mem_size)
{ {
if(mpDeviceBuf) if(mpDeviceBuf)
...@@ -92,6 +106,27 @@ struct DeviceMem
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
} }
} }
// construct a host tensor with type T
template <typename T>
HostTensor<T> ToHost(std::size_t cpySize)
{
// TODO: host tensor could be slightly larger than the device tensor
// we just copy all data from GPU buffer
std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
HostTensor<T> h_({host_elements});
if(mpDeviceBuf)
{
HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
}
return h_;
}
template <typename T>
HostTensor<T> ToHost()
{
return ToHost<T>(mMemSize);
}
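// Illustrative round trip using the new HostTensor constructor and ToHost() above
// (an assumed usage sketch, not code from this commit):
//   HostTensor<float> h({1024});
//   FillUniformDistribution<float>{-1.f, 1.f}(h);
//   DeviceMem d(h);                    // hipMalloc + copy host -> device
//   /* ... run a kernel that writes into d.GetDeviceBuffer() ... */
//   auto h_out = d.ToHost<float>();    // copy the whole buffer back into a 1-D HostTensor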
void SetZero() const void SetZero() const
{ {
if(mpDeviceBuf) if(mpDeviceBuf)
......
...@@ -13,6 +13,7 @@
#include <unordered_set> #include <unordered_set>
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -22,13 +23,44 @@ struct FillUniformDistribution
float a_{-5.f}; float a_{-5.f};
float b_{5.f}; float b_{5.f};
std::optional<uint32_t> seed_{11939}; std::optional<uint32_t> seed_{11939};
// ATTENTION: the threaded path does not guarantee the same distribution across threads
bool threaded = false;
template <typename ForwardIter> template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const void operator()(ForwardIter first, ForwardIter last) const
{ {
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); if(threaded)
std::uniform_real_distribution<float> dis(a_, b_); {
std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); }); uint32_t num_thread = std::thread::hardware_concurrency();
auto total = static_cast<std::size_t>(std::distance(first, last));
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t iw_begin = it * work_per_thread;
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
auto thread_f = [this, total, iw_begin, iw_end, &first] {
if(iw_begin > total || iw_end > total)
return;
// need to make each thread unique, add an offset to current seed
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
: std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
return ck_tile::type_convert<T>(dis(gen));
});
};
threads[it] = joinable_thread(thread_f);
}
}
else
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
}
} }
template <typename ForwardRange> template <typename ForwardRange>
...@@ -115,13 +147,44 @@ struct FillNormalDistribution
float mean_{0.f}; float mean_{0.f};
float variance_{1.f}; float variance_{1.f};
std::optional<uint32_t> seed_{11939}; std::optional<uint32_t> seed_{11939};
// ATTENTION: the threaded path does not guarantee the same distribution across threads
bool threaded = false;
template <typename ForwardIter> template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const void operator()(ForwardIter first, ForwardIter last) const
{ {
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); if(threaded)
std::normal_distribution<float> dis(mean_, std::sqrt(variance_)); {
std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); }); uint32_t num_thread = std::thread::hardware_concurrency();
auto total = static_cast<std::size_t>(std::distance(first, last));
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t iw_begin = it * work_per_thread;
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
auto thread_f = [this, total, iw_begin, iw_end, &first] {
if(iw_begin > total || iw_end > total)
return;
// need to make each thread unique, add an offset to current seed
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
: std::random_device{}());
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
return ck_tile::type_convert<T>(dis(gen));
});
};
threads[it] = joinable_thread(thread_f);
}
}
else
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
std::generate(
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
}
} }
template <typename ForwardRange> template <typename ForwardRange>
...@@ -235,6 +298,44 @@ struct FillMonotonicSeq
} }
}; };
template <typename T, bool IsAscending = true>
struct FillStepRange
{
float start_value_{0};
float end_value_{3};
float step_{1};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::generate(first, last, [=, n = start_value_]() mutable {
auto tmp = n;
n += step_;
if constexpr(IsAscending)
{
if(n > end_value_)
n = start_value_;
}
else
{
if(n < end_value_)
n = start_value_;
}
return type_convert<T>(tmp);
});
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const -> std::void_t<
decltype(std::declval<const FillStepRange&>()(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
};
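// Illustrative example (not from the diff): FillStepRange<float>{0, 3, 1} over 8 elements
// produces 0, 1, 2, 3, 0, 1, 2, 3 -- the value wraps back to start_value_ once it passes
// end_value_. With IsAscending = false and a negative step_, the comparison flips, e.g.
// FillStepRange<float, false>{0, -3, -1} yields 0, -1, -2, -3, 0, -1, ...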
template <typename T> template <typename T>
struct FillConstant struct FillConstant
{ {
......
...@@ -8,12 +8,13 @@
#include <iostream> #include <iostream>
#include <iomanip> #include <iomanip>
#include <numeric> #include <numeric>
#include <thread>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <functional> #include <functional>
#include <fstream>
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/ranges.hpp" #include "ck_tile/host/ranges.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -213,23 +214,6 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
return HostTensorDescriptor(new_lengths, new_strides); return HostTensorDescriptor(new_lengths, new_strides);
} }
struct joinable_thread : std::thread
{
template <typename... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
{
}
joinable_thread(joinable_thread&&) = default;
joinable_thread& operator=(joinable_thread&&) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <typename F, typename... Xs> template <typename F, typename... Xs>
struct ParallelTensorFunctor struct ParallelTensorFunctor
{ {
...@@ -590,6 +574,107 @@ struct HostTensor
size() * FromSize / ToSize}; size() * FromSize / ToSize};
} }
friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
{
os << t.mDesc;
os << "[";
for(typename Data::size_type idx = 0; idx < t.mData.size(); ++idx)
{
if(0 < idx)
{
os << ", ";
}
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
{
os << type_convert<float>(t.mData[idx]) << " #### ";
}
else
{
os << t.mData[idx];
}
}
os << "]";
return os;
}
// read data from a file, as dtype
// the file could be dumped from torch as (the target tensor is t here)
// numpy.savetxt("f.txt", t.view(-1).numpy())
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy()) # from cuda to cpu to save
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d") # save as int
// will output f.txt, each line is a value
// dtype=float or int; internally the value is cast to the real type T
void loadtxt(std::string file_name, std::string dtype = "float")
{
std::ifstream file(file_name);
if(file.is_open())
{
std::string line;
index_t cnt = 0;
while(std::getline(file, line))
{
if(cnt >= static_cast<index_t>(mData.size()))
{
throw std::runtime_error(std::string("data read from file:") + file_name +
" is too big");
}
if(dtype == "float")
{
mData[cnt] = type_convert<T>(std::stof(line));
}
else if(dtype == "int" || dtype == "int32")
{
mData[cnt] = type_convert<T>(std::stoi(line));
}
cnt++;
}
file.close();
if(cnt < static_cast<index_t>(mData.size()))
{
std::cerr << "Warning! reading from file:" << file_name
<< ", does not match the size of this tensor" << std::endl;
}
}
else
{
// Throw if the file cannot be opened.
throw std::runtime_error(std::string("unable to open file:") + file_name);
}
}
// can save to a txt file and read from torch as:
// torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
void savetxt(std::string file_name, std::string dtype = "float")
{
std::ofstream file(file_name);
if(file.is_open())
{
for(auto& itm : mData)
{
if(dtype == "float")
file << type_convert<float>(itm) << std::endl;
else if(dtype == "int")
file << type_convert<int>(itm) << std::endl;
else
// TODO: we didn't implement operator<< for all custom
// data types; fall back to float here to avoid a compile error
file << type_convert<float>(itm) << std::endl;
}
file.close();
}
else
{
// Throw if the file cannot be opened.
throw std::runtime_error(std::string("unable to open file:") + file_name);
}
}
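// Illustrative round trip (assumed usage, matching the comments above): dump a tensor from
// the host side and reload it, or hand it to torch:
//   HostTensor<fp16_t> t({2, 3});
//   t.savetxt("f.txt");      // one float per line
//   t.loadtxt("f.txt");      // reads it back, casting each value to fp16_t
// and on the python side:
//   torch.from_numpy(np.loadtxt('f.txt', dtype=np.float32)).view([2, 3]).contiguous()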
Descriptor mDesc; Descriptor mDesc;
Data mData; Data mData;
}; };
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <thread>
#include <utility>
namespace ck_tile {
struct joinable_thread : std::thread
{
template <typename... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
{
}
joinable_thread(joinable_thread&&) = default;
joinable_thread& operator=(joinable_thread&&) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
// number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
// 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4
// -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
// c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
///
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
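// Worked check of the padding bound above (illustrative): with input_tokens = 5, topk = 3,
// num_experts = 6 and M_a = 4, the updated formula gives
//   max_num_tokens_padded = topk * input_tokens + num_experts * M_a - topk
//                         = 3 * 5 + 6 * 4 - 3 = 36
// while the sorted_token_ids example shown (28 entries) is the actual post-pad length,
// i.e. num_tokens_post_padded_ptr <= max_num_tokens_padded always holds.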
template <typename AccDataType, // you only need to explicitly set this one
typename Activation, // ck_tile::element_wise::Gelu
typename ADataType,
typename GDataType,
typename DDataType,
typename ODataType,
typename AScaleDataType,
typename GScaleDataType,
typename DScaleDataType,
typename YSmoothScaleDataType,
typename TopkWeightDataType,
typename IndexDataType>
void reference_fused_moe(
const ck_tile::HostTensor<ADataType>& a_host, // [tokens, hidden_size]
const ck_tile::HostTensor<GDataType>& g_host, // [experts, interme_size_0, hidden_size]
const ck_tile::HostTensor<DDataType>& d_host, // [experts, hidden_size, interme_size_1]
const ck_tile::HostTensor<AScaleDataType>& sa_host, // [tokens, 1],
const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size_0]
const ck_tile::HostTensor<DScaleDataType>& sd_host, // [experts, 1, hidden_size],
const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
ck_tile::HostTensor<ODataType>& o_host, // [tokens, hidden_size]
const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host, // [max_num_tokens_padded]
const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
const ck_tile::HostTensor<IndexDataType>&
sorted_expert_ids_host, // [(max_num_tokens_padded + block_size - 1) / block_size]
const ck_tile::HostTensor<IndexDataType>& num_sorted_tiles_host, // [1]
const ck_tile::HostTensor<IndexDataType>&
token_ids_host, // [tokens, topk] --> ugly!!! remove in the future
ck_tile::index_t block_m,
ck_tile::index_t tokens,
ck_tile::index_t experts,
ck_tile::index_t hidden_size,
ck_tile::index_t intermediate_size, // this size is for gate/up
ck_tile::index_t topk,
ck_tile::index_t gate_only)
{
assert(sorted_token_ids_host.get_num_of_dimension() == 1);
assert(sorted_weight_host.get_num_of_dimension() == 1);
assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
assert(num_sorted_tiles_host.get_element_size() == 1);
ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
ck_tile::index_t intermediate_size_0 = intermediate_size;
ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);
// TODO: better remove this in the future, or modify the token_id value
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
{
if(token_ids_host(token_id_, i_) == expert_id_)
return i_;
}
throw std::runtime_error("not correct token/expert pair\n");
return -1; // TODO: not correct!!
};
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
// assert();
auto f = [&](auto i_flatten) {
ck_tile::index_t i_tile = i_flatten / block_m;
if(i_tile >= num_sorted_tiles)
return;
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
if(i_token >= tokens)
return;
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
auto weight = sorted_weight_host.mData[i_flatten];
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
// first gemm
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
{
acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
}
acc_0(0, i_n) = acc;
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
}
ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
if(gate_only)
{
if(intermediate_size_1 != intermediate_size_0)
throw std::runtime_error(
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
", 1:" + std::to_string(intermediate_size_1));
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
Activation{}(y(0, i_n), acc_0(0, i_n));
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
}
}
else
{
if(intermediate_size_1 * 2 != intermediate_size_0)
throw std::runtime_error(
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
", 1:" + std::to_string(intermediate_size_1));
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
AccDataType tmp;
Activation{}(tmp, acc_0(0, i_n));
y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
}
}
// second gemm, loop along gemm-n
ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
{
acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
}
acc_1(0, i_n) = acc * weight; // multiply the weight here
}
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n);
}
};
// make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f, max_num_tokens_padded)(1);
// reduce
auto r = [&](auto i_token) {
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
AccDataType acc = type_convert<AccDataType>(0);
for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
{
acc += out_topk_tokens(i_token, i_topk, i_n);
}
o_host(i_token, i_n) = type_convert<ODataType>(acc);
}
};
make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());
(void)num_sorted_tiles_host;
(void)sa_host;
(void)sg_host;
(void)sd_host;
(void)sy_host;
}
} // namespace ck_tile
...@@ -183,4 +183,116 @@ void reference_gemm_gpu(DeviceMem& a_device,
return; return;
} }
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_batched_gemm_gpu(DeviceMem& a_device,
DeviceMem& b_device,
DeviceMem& c_device,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c,
index_t batch_stride_A,
index_t batch_stride_B,
index_t batch_stride_C,
index_t batch_count)
{
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType));
hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType));
hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType));
if(errA != hipSuccess)
{
std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
<< std::endl;
return; // Early exit on error
}
if(errB != hipSuccess)
{
std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
<< std::endl;
return; // Early exit on error
}
if(errC != hipSuccess)
{
std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
<< std::endl;
return; // Early exit on error
}
errA = hipMemcpy(d_A,
a_device.GetDeviceBuffer(),
batch_count * M * K * sizeof(ADataType),
hipMemcpyHostToDevice);
if(errA != hipSuccess)
{
std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
}
errB = hipMemcpy(d_B,
b_device.GetDeviceBuffer(),
batch_count * N * K * sizeof(BDataType),
hipMemcpyHostToDevice);
if(errB != hipSuccess)
{
std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
}
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
{
ADataType* d_ATemp = d_A + batch_id * batch_stride_A;
BDataType* d_BTemp = d_B + batch_id * batch_stride_B;
CDataType* d_CTemp = d_C + batch_id * batch_stride_C;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
}
errC = hipMemcpy(c_device.GetDeviceBuffer(),
d_C,
batch_count * M * N * sizeof(CDataType),
hipMemcpyDeviceToHost);
if(errC != hipSuccess)
{
std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
}
errA = hipFree(d_A);
if(errA != hipSuccess)
{
std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
}
errB = hipFree(d_B);
if(errB != hipSuccess)
{
std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
}
errC = hipFree(d_C);
if(errC != hipSuccess)
{
std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
}
return;
}
} // namespace ck_tile } // namespace ck_tile
...@@ -8,6 +8,9 @@
namespace ck_tile { namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
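// Example of the mock-id packing above (illustrative): token_id = 5, topk_id = 2 packs to
//   MOE_SORTING_MOCK_ID(5, 2) = (5 & 0x00ffffff) | ((2 & 0xff) << 24) = 0x02000005
// i.e. the low 24 bits carry the token id and the top 8 bits carry which top-k slot it came from.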
template <typename WeightType, typename IndexType = index_t> template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids, CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
const HostTensor<WeightType>& weights, const HostTensor<WeightType>& weights,
...@@ -20,8 +23,14 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
{ {
const index_t num_token = topk_ids.mDesc.get_lengths()[0]; const index_t num_token = topk_ids.mDesc.get_lengths()[0];
const index_t topk = topk_ids.mDesc.get_lengths()[1]; const index_t topk = topk_ids.mDesc.get_lengths()[1];
std::vector<std::vector<IndexType>> expert_tokens(experts, // allocate a temp buffer, and fill the value with [number_token|topk]
std::vector<IndexType>(unit_size, num_token)); std::vector<std::vector<IndexType>> expert_tokens(
experts,
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
std::vector<IndexType>(unit_size, MOE_SORTING_MOCK_ID(num_token, topk)));
#else
std::vector<IndexType>(unit_size, num_token));
#endif
std::vector<std::vector<WeightType>> expert_token_weights( std::vector<std::vector<WeightType>> expert_token_weights(
experts, std::vector<WeightType>(unit_size, 0)); experts, std::vector<WeightType>(unit_size, 0));
std::vector<IndexType> expert_slices(experts, 1); std::vector<IndexType> expert_slices(experts, 1);
...@@ -42,12 +51,19 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
expert_token_weights[e].resize(new_size); expert_token_weights[e].resize(new_size);
for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++) for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
{ {
expert_tokens[e][i] = num_token; #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
expert_tokens[e][i] = MOE_SORTING_MOCK_ID(num_token, topk);
#else
expert_tokens[e][i] = num_token;
#endif
expert_token_weights[e][i] = 0; expert_token_weights[e][i] = 0;
} }
} }
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
expert_tokens[e][idx] = t; expert_tokens[e][idx] = MOE_SORTING_MOCK_ID(t, k);
#else
expert_tokens[e][idx] = t;
#endif
expert_token_weights[e][idx] = w; expert_token_weights[e][idx] = w;
expert_slice_idxs[e]++; expert_slice_idxs[e]++;
} }
...@@ -75,4 +91,7 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
unit_cnt *= unit_size; unit_cnt *= unit_size;
return; return;
} }
#undef MOE_SORTING_MOCK_ID
} // namespace ck_tile } // namespace ck_tile
...@@ -16,7 +16,7 @@ namespace ck_tile {
*/ */
template <typename DataType> template <typename DataType>
CK_TILE_HOST void CK_TILE_HOST void
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims) reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
{ {
const auto x_len = x.mDesc.get_lengths(); const auto x_len = x.mDesc.get_lengths();
const auto y_len = y.mDesc.get_lengths(); const auto y_len = y.mDesc.get_lengths();
...@@ -43,7 +43,7 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
std::vector<size_t> tmp(rank, 0); std::vector<size_t> tmp(rank, 0);
for(index_t i = 0; i < rank; i++) for(index_t i = 0; i < rank; i++)
{ {
tmp[dims[i]] = y_coord[i]; tmp[perm[i]] = y_coord[i];
} }
return tmp; return tmp;
}(); }();
...@@ -54,4 +54,23 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
} }
template <typename DataType>
CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
{
auto x_shape = x.get_lengths();
ck_tile::index_t rank = perm.size();
std::vector<ck_tile::index_t> y_shape = [&]() {
std::vector<ck_tile::index_t> tmp(rank, 0);
for(int i = 0; i < static_cast<int>(rank); i++)
{
tmp[i] = x_shape[perm[i]];
}
return tmp;
}();
HostTensor<DataType> y(y_shape);
reference_permute(x, y, perm);
return y;
}
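// Illustrative usage of the new overload (assumed example): permute a [2, 3, 4] tensor with
// perm = {2, 0, 1}, i.e. y_shape[i] = x_shape[perm[i]]:
//   HostTensor<float> x({2, 3, 4});
//   auto y = reference_permute(x, {2, 0, 1});   // y lengths become {4, 2, 3}
//   // element-wise, y(i0, i1, i2) == x(i1, i2, i0), matching numpy.transpose(x, (2, 0, 1))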
} // namespace ck_tile } // namespace ck_tile
...@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
static constexpr bool kSaveX = Problem::kSaveX; static constexpr bool kSaveX = Problem::kSaveX;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseMax3 = true; // TODO - Move to trait
static constexpr const char* name = []() { static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync) if constexpr(kNeedCrossWarpSync)
...@@ -69,9 +70,16 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
auto reduce_square_sum_func = ReduceOp::SquareAdd{}; auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{}; auto reduce_sum_func = ReduceOp::Add{};
auto reduce_absmax_func = ReduceOp::AbsMax{}; auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_max_func = ReduceOp::Max{}; auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>(); float rtn;
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>(); asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
auto reduce_max_func = ReduceOp::Max{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync = auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>(); Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
...@@ -116,8 +124,23 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
}); });
// compute absmax, each-thread->cross-lane->cross-warp // compute absmax, each-thread->cross-lane->cross-warp
auto absmax = block_reduce2d( auto absmax = [&]() {
y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func); constexpr auto x_size_per_row =
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
x_size_per_row % 2 == 0)
{
return block_reduce2d(y,
reduce_absmax_func.GetIdentityValue<ComputeDataType>(),
reduce_absmax3_func,
sequence<1, 2>{});
}
else
{
return block_reduce2d(
y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
}
}();
block_reduce2d_sync(absmax, reduce_max_func); block_reduce2d_sync(absmax, reduce_max_func);
block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func); block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
......
...@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
static constexpr bool kSaveX = Problem::kSaveX; static constexpr bool kSaveX = Problem::kSaveX;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseMax3 = true; // TODO - Move to trait
static constexpr const char* name = []() { static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync) if constexpr(kNeedCrossWarpSync)
...@@ -76,9 +77,16 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
auto reduce_square_sum_func = ReduceOp::SquareAdd{}; auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{}; auto reduce_sum_func = ReduceOp::Add{};
auto reduce_absmax_func = ReduceOp::AbsMax{}; auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_max_func = ReduceOp::Max{}; auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>(); float rtn;
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>(); asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
auto reduce_max_func = ReduceOp::Max{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync = auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>(); Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
...@@ -177,7 +185,13 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
y(idx) = type_convert<ComputeDataType>(y_); y(idx) = type_convert<ComputeDataType>(y_);
}); });
block_reduce2d(y, absmax, reduce_absmax_func); constexpr auto x_size_per_row =
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
x_size_per_row % 2 == 0)
block_reduce2d(y, absmax, reduce_absmax3_func, sequence<1, 2>{});
else
block_reduce2d(y, absmax, reduce_absmax_func);
if constexpr(kSaveX) if constexpr(kSaveX)
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
......
...@@ -572,6 +572,105 @@ struct FastGelu
} }
}; };
struct FastGeluAsm
{
template <typename Y, typename X>
CK_TILE_HOST void operator()(Y& y, const X& x) const;
template <typename Y, typename X>
CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = exp(u);
y = x / (1.f + emu);
}
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template <>
CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
{
const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
float tmp;
asm volatile("v_mul_f32 %[v_tmp], %[v_x], %[v_x] ; x*x\n"
"v_fma_f32 %[v_tmp], %[v_tmp], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
"v_mul_f32 %[v_tmp], %[v_tmp], %[v_x] ; x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp], %[v_tmp], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
"v_exp_f32 %[v_tmp], %[v_tmp] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
"s_nop 0 ; hazard for exp\n"
"v_add_f32 %[v_tmp], %[v_tmp], 1.0 ; emu+1.0f\n"
"v_rcp_f32 %[v_tmp], %[v_tmp] ; 1/(emu+1.0f)\n"
"s_nop 0 ; hazard for rcp \n"
"v_mul_f32 %[v_y], %[v_tmp], %[v_x] ; x * 1/(emu+1f)\n"
: [v_y] "=v"(y), [v_tmp] "+v"(tmp)
: [v_x] "v"(x), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_)
:);
}
template <>
CK_TILE_HOST void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
{
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u0 = x.x * (c1 * x.x * x.x + c2);
const float emu0 = exp(u0);
y.x = x.x / (1.f + emu0);
const float u1 = x.y * (c1 * x.y * x.y + c2);
const float emu1 = exp(u1);
y.y = x.y / (1.f + emu1);
}
// this is a packed version to remove the data hazard for the trans (exp/rcp) ops
template <>
CK_TILE_DEVICE void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
{
const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
float c2 = -2.0 * 0.797885f;
const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
float tmp0, tmp1;
float y0 = x.x, y1 = x.y;
asm volatile(
"v_mul_f32 %[v_tmp0], %[v_y0], %[v_y0] ; x*x\n"
"v_mul_f32 %[v_tmp1], %[v_y1], %[v_y1] ; x*x\n"
"v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
"v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_y0] ; x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_y1] ; x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
"v_exp_f32 %[v_tmp0], %[v_tmp0] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
"v_exp_f32 %[v_tmp1], %[v_tmp1] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
"v_add_f32 %[v_tmp0], %[v_tmp0], 1.0 ; emu+1.0f\n"
"v_add_f32 %[v_tmp1], %[v_tmp1], 1.0 ; emu+1.0f\n"
"v_rcp_f32 %[v_tmp0], %[v_tmp0] ; 1/(emu+1.0f)\n"
"v_rcp_f32 %[v_tmp1], %[v_tmp1] ; 1/(emu+1.0f)\n"
"v_mul_f32 %[v_y0], %[v_tmp0], %[v_y0] ; x * 1/(emu+1f)\n"
"v_mul_f32 %[v_y1], %[v_tmp1], %[v_y1] ; x * 1/(emu+1f)\n"
: [v_y0] "+v"(y0),
[v_y1] "+v"(y1),
[v_c2] "+v"(c2),
// NOTE! it is entirely possible that c2/y0/y1 share the same register, since they are
// all local temp variables. We explicitly hint the compiler that they may be read and
// written ("+v") so that distinct registers get allocated; the side effect is that the
// c2 = ... initialization may be re-issued for every such inline asm block
[v_tmp0] "+v"(tmp0),
[v_tmp1] "+v"(tmp1)
: [s_c1] "s"(c1), [s_log2e] "s"(log2e_)
:);
y.x = y0;
y.y = y1;
}
};
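// Quick sanity check of the approximation used above (illustrative): both paths evaluate
//   y = x / (1 + exp(u)),  u = -2 * x * (0.035677 * x^2 + 0.797885)
// At x = 1: u = -2 * 0.833562 = -1.667124, exp(u) ~= 0.1888, so y ~= 1 / 1.1888 ~= 0.8412,
// close to the exact gelu(1) ~= 0.8413. The asm path computes exp via v_exp_f32 (exp2) by
// pre-multiplying u with log2(e), which is the 0x3fb8aa3b constant.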
// https://paperswithcode.com/method/gelu // https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2))) // y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu struct Gelu
......
...@@ -3,9 +3,8 @@
#pragma once #pragma once
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp" #include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" #include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
namespace ck_tile {
// A: async load to LDS, B: loaded directly into AGPRs
// B matrix is preshuffled in br*kr*w
// requires 4 waves, occupancy = 1
// agpr usage: 256
// vgpr usage: 64 (A local) + 64 (acc) + 8 (os_a) + 8 (os_b) = 144 (rem: 112)
//
// for this gemm, 4 16x16x16 transposed layout
// input A vgpr layout
// v0-v15: [ 0:15](gemm_m)x128(gemm_k)
// v16-v31: [16:31](gemm_m)x128(gemm_k)
// input B vgpr layout
// v0-v15: [ 0: 15](gemm_n)x128(gemm_k)
// v16-v31: [ 64: 79](gemm_n)x128(gemm_k)
// ......................
// v111-v127: [448:463](gemm_n)x128(gemm_k)
// output C vgpr layout
// v0-v3 : [ 0:15](gemm_m)x[ 0: 15](gemm_n)
// v4-v7 : [16:31](gemm_m)x[ 0: 15](gemm_n)
// v8-v11: [ 0:15](gemm_m)x[64: 79](gemm_n)
// v12-v15: [16:31](gemm_m)x[64: 79](gemm_n)
// ......................
// v56-v59: [ 0:15](gemm_m)x[448:463](gemm_n)
// v60-v63: [16:31](gemm_m)x[448:463](gemm_n)
struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
{
static constexpr index_t Block_M = 32;
static constexpr index_t Block_N = 512;
static constexpr index_t Block_K = 128;
static constexpr index_t WarpPerBlock_M = 1;
static constexpr index_t WarpPerBlock_N = 4;
static constexpr index_t WarpPerBlock_K = 1;
static constexpr index_t NumWarps = 4;
static constexpr index_t Warp_M = 16;
static constexpr index_t Warp_N = 16;
static constexpr index_t Warp_K = 32; // 16 * SubKPacks
static constexpr index_t BlockSize = 256;
static constexpr index_t SubKPacks = 2; // used to guarantee every thread can do a dwordx4 access
// TODO: note Nr/Kr/W need to account for SubKPacks
static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
static constexpr index_t Block_Kr = Block_K / Warp_K; // 4
static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4
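// compile-time sanity checks: the derived constants above follow directly from
// the block/warp shape (a minimal illustrative set, not exhaustive)
static_assert(Block_W == 512 && Block_Nr == 32 && Block_Kr == 4, "preshuffled B extents");
static_assert(Repeat_M == 2 && Repeat_N == 8 && Repeat_K == 4, "per-wave repeats");
static_assert(NumWarps == WarpPerBlock_M * WarpPerBlock_N * WarpPerBlock_K, "4 waves per block");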
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2, 1>, // !! note here is different
sequence<0, 0>>{};
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
{
using CDataType = float;
constexpr auto c_block_dstr = MakeCBlockDist();
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
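// Per-thread C footprint (assuming a 64-lane wavefront): each 16x16 warp-gemm
// accumulator holds 16*16/64 = 4 floats per lane, and every thread covers
// Repeat_M x Repeat_N = 2 x 8 warp tiles, i.e. 2 * 8 * 4 = 64 floats -- exactly
// the 16 x fp32x4 v_acc[] scratch the micro-kernels below copy into
// c.get_thread_buffer().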
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
// A async->LDS
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack_; // pad between warps
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
if constexpr(LanesPerK >= warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
if constexpr(wavesPerK > NumWarps)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = Block_M / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
number<warpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NumIssues>{}),
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
number<warpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
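// Concrete numbers for this block shape (assuming warpSize == 64 on gfx9):
// KVector = 2 and Block_K = 128 give LanesPerK = 64 == warpSize, so wavesPerK = 1,
// wavesPerM = 4 and NumIssues = Block_M / wavesPerM = 8 async issues per lane.
// One wave's K row occupies warpSize * KVector + KPad = 136 elements and the m0
// stride is NumWarps * 136 = 544, so the A tile lives inside the 32 * (128 + 8)
// element (8704-byte) LDS region that GetSmemSize() below reserves.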
// template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
{
// load from LDS to register, every wave has same layout
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KPad = KPack_; // pad between warps
constexpr index_t kAMLane = 16;
constexpr index_t kABKLane = 4;
constexpr index_t kABKPerLane = 4;
constexpr index_t kKIter = 2;
static_assert(KPack_ == (kABKPerLane * kKIter));
constexpr auto lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<Repeat_M>{}, // m0 y
number<kAMLane>{}, // m1 p
number<Repeat_K>{}, // k0 y
number<kABKLane>{}, // k1 p
number<KPack_>{}), // k2 y-vector
make_tuple(number<kAMLane*(Block_K + KPad)>{}, // m0
number<Block_K + KPad>{}, // m1
number<kABKLane * KPack_>{}, // k0
number<KPack_>{}, // k1
number<1>{}), // k2
number<KPack_>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
make_merge_transform(
make_tuple(number<Repeat_K>{}, number<kABKLane>{}, number<KPack_>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
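// Concrete numbers: the K extent is Repeat_K * kABKLane * KPack_ = 4 * 4 * 8 = 128
// = Block_K and the M extent is Repeat_M * kAMLane = 2 * 16 = 32 = Block_M. Each M
// row spans Block_K + KPad = 136 elements with an m0 stride of 16 * 136 = 2176, so
// the load descriptor addresses the same 32 x (128 + 8) element LDS region as the
// store descriptor above.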
static constexpr auto GetGemm_AWarpEnc()
{
constexpr index_t kAMLane = 16;
constexpr index_t kABKLane = 4;
constexpr index_t kABKPerLane = 4;
constexpr index_t kKIter = 2;
using enc_ = tile_distribution_encoding<
sequence<>,
tuple<sequence<kAMLane>, sequence<kABKLane, kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
return enc_{};
}
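// Per-warp A encoding: kAMLane = 16 lanes along gemm_m times kABKLane = 4 lanes
// along gemm_k cover one 64-lane wave, and each lane holds kABKPerLane * kKIter = 8
// contiguous K elements (== KPack_), so one warp reads a 16 (M) x 32 (K) A tile,
// matching Warp_M x Warp_K.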
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return 32 * (128 + 8) * sizeof(bf16_t);
}
};
struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
{
using ADataType = bf16_t;
using BDataType = bf16_t;
// TODO: needs to be paired with tile_window_linear!
// TODO: init_raw() must be called before calling this function!
template <typename ARes, typename ACoords, typename BRes, typename BCoords>
CK_TILE_DEVICE auto
operator()(const ARes& res_a,
const ACoords& cached_coords_a,
const BRes& res_b,
const BCoords& cached_coords_b,
CK_TILE_LDS_ADDR void* smem,
index_t k,
index_t tile_offset_a, // for each tile, the offset to move for each unroll
index_t tile_offset_b) // for each tile, the offset to move for each unroll
{
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N);
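// i.e. ACoords carries Block_M * Block_K / BlockSize / 2 = 8 per-thread dword
// offsets for the async A loads and BCoords carries one base offset per N repeat
// (Repeat_N = 8), matching the v_os_a0..7 / v_os_b0..7 operands of the asm below.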
auto a_sst = make_tile_window(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
MakeLdsStoreDesc_A().get_lengths(),
{0, 0, 0});
auto a_sld = [&]() {
constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<WarpPerBlock_N>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
MakeLdsLoadDesc_A().get_lengths(),
{0, 0},
make_static_tile_distribution(a_block_dstr_encode));
}();
const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
constexpr auto smem_buf_size =
MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
static_assert(a_sld.get_num_of_access() == 8);
constexpr auto sld_os = generate_tuple(
[&](auto i_access) {
return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
},
number<a_sld.get_num_of_access()>{});
index_t loop_cnt = k / Block_K;
// this is the acc thread buffer
fp32x4_t v_acc[16]{.0f};
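// 16 x fp32x4 = 64 accumulator VGPRs, zero-initialized here and updated in place
// through the "+v" operands of the asm block (the "64(acc)" in the header comment).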
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
: [s_loop_cnt]"+s"(loop_cnt),
[v_acc_0]"+v"(v_acc[0]),
[v_acc_1]"+v"(v_acc[1]),
[v_acc_2]"+v"(v_acc[2]),
[v_acc_3]"+v"(v_acc[3]),
[v_acc_4]"+v"(v_acc[4]),
[v_acc_5]"+v"(v_acc[5]),
[v_acc_6]"+v"(v_acc[6]),
[v_acc_7]"+v"(v_acc[7]),
[v_acc_8]"+v"(v_acc[8]),
[v_acc_9]"+v"(v_acc[9]),
[v_acc_10]"+v"(v_acc[10]),
[v_acc_11]"+v"(v_acc[11]),
[v_acc_12]"+v"(v_acc[12]),
[v_acc_13]"+v"(v_acc[13]),
[v_acc_14]"+v"(v_acc[14]),
[v_acc_15]"+v"(v_acc[15]),
[s_mem_]"+r"(smem)
: [s_res_a0]"s"(res_a[0]),
[s_res_a1]"s"(res_a[1]),
[s_res_a2]"s"(res_a[2]),
[s_res_a3]"s"(res_a[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
[s_m0_init]"s"(m0_init_value),
[s_size_per_issue]"s"(size_per_issue),
[smem_sz]"n"(smem_buf_size), //(smem_buf_size),
[sld_os_0]"n"(sld_os[number<0>{}].value),
[sld_os_1]"n"(sld_os[number<1>{}].value),
[sld_os_2]"n"(sld_os[number<2>{}].value),
[sld_os_3]"n"(sld_os[number<3>{}].value),
[sld_os_4]"n"(sld_os[number<4>{}].value),
[sld_os_5]"n"(sld_os[number<5>{}].value),
[sld_os_6]"n"(sld_os[number<6>{}].value),
[sld_os_7]"n"(sld_os[number<7>{}].value),
[s_tile_os_a]"s"(tile_offset_a_bytes),
[s_tile_os_b]"s"(tile_offset_b_bytes)
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
"s86", // s86 as tmp
"v64", "v65", "v66", "v67", "v68", "v69",
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
"v124", "v125", "v126", "v127"
);
// clang-format on
#pragma clang diagnostic pop
// copy the local accumulator scratch into the distributed C tile and return it
auto c = MakeCBlockTile();
for(auto i = 0; i < 16; i++)
{
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
}
return c;
}
};
struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
{
using ADataType = fp16_t;
using BDataType = fp16_t;
// TODO: needs to be paired with tile_window_linear!
// TODO: init_raw() must be called before calling this function!
template <typename ARes, typename ACoords, typename BRes, typename BCoords>
CK_TILE_DEVICE auto
operator()(const ARes& res_a,
const ACoords& cached_coords_a,
const BRes& res_b,
const BCoords& cached_coords_b,
CK_TILE_LDS_ADDR void* smem,
index_t k,
index_t tile_offset_a, // for each tile, the offset to move for each unroll
index_t tile_offset_b) // for each tile, the offset to move for each unroll
{
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N);
auto a_sst = make_tile_window(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
MakeLdsStoreDesc_A().get_lengths(),
{0, 0, 0});
auto a_sld = [&]() {
constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<WarpPerBlock_N>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
MakeLdsLoadDesc_A().get_lengths(),
{0, 0},
make_static_tile_distribution(a_block_dstr_encode));
}();
const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
constexpr auto smem_buf_size =
MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
static_assert(a_sld.get_num_of_access() == 8);
constexpr auto sld_os = generate_tuple(
[&](auto i_access) {
return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
},
number<a_sld.get_num_of_access()>{});
index_t loop_cnt = k / Block_K;
// this is the acc thread buffer
fp32x4_t v_acc[16]{.0f};
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
: [s_loop_cnt]"+s"(loop_cnt),
[v_acc_0]"+v"(v_acc[0]),
[v_acc_1]"+v"(v_acc[1]),
[v_acc_2]"+v"(v_acc[2]),
[v_acc_3]"+v"(v_acc[3]),
[v_acc_4]"+v"(v_acc[4]),
[v_acc_5]"+v"(v_acc[5]),
[v_acc_6]"+v"(v_acc[6]),
[v_acc_7]"+v"(v_acc[7]),
[v_acc_8]"+v"(v_acc[8]),
[v_acc_9]"+v"(v_acc[9]),
[v_acc_10]"+v"(v_acc[10]),
[v_acc_11]"+v"(v_acc[11]),
[v_acc_12]"+v"(v_acc[12]),
[v_acc_13]"+v"(v_acc[13]),
[v_acc_14]"+v"(v_acc[14]),
[v_acc_15]"+v"(v_acc[15]),
[s_mem_]"+r"(smem)
: [s_res_a0]"s"(res_a[0]),
[s_res_a1]"s"(res_a[1]),
[s_res_a2]"s"(res_a[2]),
[s_res_a3]"s"(res_a[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
[s_m0_init]"s"(m0_init_value),
[s_size_per_issue]"s"(size_per_issue),
[smem_sz]"n"(smem_buf_size), //(smem_buf_size),
[sld_os_0]"n"(sld_os[number<0>{}].value),
[sld_os_1]"n"(sld_os[number<1>{}].value),
[sld_os_2]"n"(sld_os[number<2>{}].value),
[sld_os_3]"n"(sld_os[number<3>{}].value),
[sld_os_4]"n"(sld_os[number<4>{}].value),
[sld_os_5]"n"(sld_os[number<5>{}].value),
[sld_os_6]"n"(sld_os[number<6>{}].value),
[sld_os_7]"n"(sld_os[number<7>{}].value),
[s_tile_os_a]"s"(tile_offset_a_bytes),
[s_tile_os_b]"s"(tile_offset_b_bytes)
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
"s86", // s86 as tmp
"v64", "v65", "v66", "v67", "v68", "v69",
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
"v124", "v125", "v126", "v127"
);
// clang-format on
#pragma clang diagnostic pop
// copy the local accumulator scratch into the distributed C tile and return it
auto c = MakeCBlockTile();
for(auto i = 0; i < 16; i++)
{
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
}
return c;
}
};
} // namespace ck_tile