Unverified Commit 3c5717df authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge branch 'develop' into gemm_elementwise_gemm

parents 171b9030 d9f1ead3
......@@ -109,4 +109,22 @@ CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
#pragma clang diagnostic pop
}
template <typename CompareTo, typename... Rest>
struct is_any_of : std::false_type
{
};
template <typename CompareTo, typename FirstType>
struct is_any_of<CompareTo, FirstType> : std::is_same<CompareTo, FirstType>
{
};
template <typename CompareTo, typename FirstType, typename... Rest>
struct is_any_of<CompareTo, FirstType, Rest...>
: std::integral_constant<bool,
std::is_same<CompareTo, FirstType>::value ||
is_any_of<CompareTo, Rest...>::value>
{
};
} // namespace ck_tile
......@@ -51,16 +51,18 @@ struct composes<F>
template <typename... Ts>
__host__ __device__ composes(Ts&&...)->composes<remove_cvref_t<Ts>...>;
template <typename To>
template <typename SaturateType>
struct saturates
{
template <typename From>
CK_TILE_HOST_DEVICE constexpr auto operator()(const From& from) const
-> std::enable_if_t<std::is_arithmetic_v<From>, From>
// NOTE: this function does not return SaturateType value
// it is user's responsiblity to do further cast or not
template <typename AccType>
CK_TILE_HOST_DEVICE constexpr auto operator()(const AccType& a_) const
-> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
{
return clamp(from,
type_convert<From>(numeric<To>::lowest()),
type_convert<From>(numeric<To>::max()));
return clamp(a_,
type_convert<AccType>(numeric<SaturateType>::lowest()),
type_convert<AccType>(numeric<SaturateType>::max()));
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -11,6 +11,7 @@
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
......@@ -19,7 +20,9 @@
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
......
......@@ -15,11 +15,14 @@
namespace ck_tile {
/*
* a host side utility, arg parser for
* -[key0]=[value0] -[key1]=[value1] ...
* a host side utility, arg parser for, either
* -[key0] = [value0, value1, value2]
* or
* -[key0]=[value0] -[key1]=[value1] ...
*/
class ArgParser
{
public:
class Arg
{
......@@ -187,6 +190,45 @@ class ArgParser
return value;
}
std::vector<std::string> get_string_vec(const std::string& name,
const std::string& delimiter = ",") const
{
if(get_str(name).empty())
{
return {};
}
std::string s = get_str(name);
std::vector<std::string> tokens;
size_t pos = 0;
std::string token;
while((pos = s.find(delimiter)) != std::string::npos)
{
token = s.substr(0, pos);
tokens.push_back(token);
s.erase(0, pos + delimiter.length());
}
tokens.push_back(s);
return tokens;
}
std::vector<int> get_int_vec(const std::string& name, const std::string& delimiter = ",") const
{
if(get_str(name).empty())
{
return {};
}
const std::vector<std::string> args = get_string_vec(name, delimiter);
std::vector<int> tokens;
tokens.reserve(static_cast<int>(args.size()));
for(const std::string& token : args)
{
int value = atoi(token.c_str());
tokens.push_back(value);
}
return tokens;
}
private:
std::unordered_map<std::string, Arg> input_map;
std::vector<std::string> keys;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -18,6 +18,114 @@
namespace ck_tile {
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_relative_threshold(const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0;
if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
{
return 0;
}
else
{
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0;
if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
{
return 0;
}
else
{
output_error = std::pow(2, -numeric_traits<OutDataType>::mant) * 0.5;
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0;
if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
{
return 0;
}
else
{
acc_error = std::pow(2, -numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
}
return std::max(acc_error, midway_error);
}
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num));
double compute_error = 0;
if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
{
return 0;
}
else
{
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0;
if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
{
return 0;
}
else
{
output_error = std::pow(2, expo - numeric_traits<OutDataType>::mant) * 0.5;
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0;
if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
{
return 0;
}
else
{
acc_error =
std::pow(2, expo - numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
}
return std::max(acc_error, midway_error);
}
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
......@@ -337,7 +445,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
}
return res;
}
......@@ -391,7 +503,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
}
return res;
}
......
......@@ -14,57 +14,41 @@ namespace detail {
template <typename OldLayout>
CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKW>)
using namespace ck_tile::tensor_layout::convolution;
if constexpr(is_any_of<OldLayout, GNCW, GKCX, GNKW>::value)
{
return {0, 1, 2, 3};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKHW>)
else if constexpr(is_any_of<OldLayout, GNCHW, GKCYX, GNKHW>::value)
{
return {0, 1, 2, 3, 4};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCDHW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCZYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
else if constexpr(is_any_of<OldLayout, GNCDHW, GKCZYX, GNKDHW>::value)
{
return {0, 1, 2, 3, 4, 5};
}
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWK>)
if constexpr(is_any_of<OldLayout, GNWC, GKXC, GNWK>::value)
{
return {0, 1, 3, 2};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWK>)
else if constexpr(is_any_of<OldLayout, GNHWC, GKYXC, GNHWK>::value)
{
return {0, 1, 4, 2, 3};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKZYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
else if constexpr(is_any_of<OldLayout, GNDHWC, GKZYXC, GNDHWK>::value)
{
return {0, 1, 5, 2, 3, 4};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGK>)
else if constexpr(is_any_of<OldLayout, NWGC, KXGC, NWGK>::value)
{
return {2, 0, 3, 1};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGK>)
else if constexpr(is_any_of<OldLayout, NHWGC, KYXGC, NHWGK>::value)
{
return {3, 0, 4, 1, 2};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KZYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
else if constexpr(is_any_of<OldLayout, NDHWGC, KZYXGC, NDHWGK>::value)
{
return {4, 0, 5, 1, 2, 3};
}
......@@ -83,11 +67,11 @@ template <typename InLayout>
CK_TILE_HOST HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param)
{
using namespace ck_tile::tensor_layout::convolution;
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCW> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCDHW>)
if constexpr(is_any_of<InLayout, GNCW, GNCHW, GNCDHW>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
......@@ -97,9 +81,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNWC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNDHWC>)
else if constexpr(is_any_of<InLayout, GNWC, GNHWC, GNDHWC>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
......@@ -109,9 +91,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NWGC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>)
else if constexpr(is_any_of<InLayout, NWGC, NHWGC, NDHWGC>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
......@@ -139,11 +119,11 @@ template <typename WeiLayout>
CK_TILE_HOST HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param)
{
using namespace ck_tile::tensor_layout::convolution;
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXC>)
if constexpr(is_any_of<WeiLayout, KXC, KYXC, KZYXC>::value)
{
if(param.G_ != 1)
{
......@@ -157,9 +137,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCX> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCZYX>)
else if constexpr(is_any_of<WeiLayout, GKCX, GKCYX, GKCZYX>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
......@@ -169,9 +147,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKZYXC>)
else if constexpr(is_any_of<WeiLayout, GKXC, GKYXC, GKZYXC>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
......@@ -181,9 +157,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXGC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXGC>)
else if constexpr(is_any_of<WeiLayout, KXGC, KYXGC, KZYXGC>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.G_),
......@@ -211,11 +185,11 @@ template <typename OutLayout>
CK_TILE_HOST HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param)
{
using namespace ck_tile::tensor_layout::convolution;
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKW> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKHW> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
if constexpr(is_any_of<OutLayout, GNKW, GNKHW, GNKDHW>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
......@@ -226,9 +200,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNWK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNHWK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
else if constexpr(is_any_of<OutLayout, GNWK, GNHWK, GNDHWK>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
......@@ -238,9 +210,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NWGK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NHWGK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
else if constexpr(is_any_of<OutLayout, NWGK, NHWGK, NDHWGK>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
......
......@@ -7,6 +7,7 @@
#include <stdint.h>
#include <stdexcept>
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <typename T>
......@@ -36,6 +37,19 @@ struct DeviceMem
mpDeviceBuf = nullptr;
}
}
template <typename T>
DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
{
if(mMemSize != 0)
{
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
else
{
mpDeviceBuf = nullptr;
}
ToDevice(t.data());
}
void Realloc(std::size_t mem_size)
{
if(mpDeviceBuf)
......@@ -92,6 +106,27 @@ struct DeviceMem
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
}
}
// construct a host tensor with type T
template <typename T>
HostTensor<T> ToHost(std::size_t cpySize)
{
// TODO: host tensor could be slightly larger than the device tensor
// we just copy all data from GPU buffer
std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
HostTensor<T> h_({host_elements});
if(mpDeviceBuf)
{
HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
}
return h_;
}
template <typename T>
HostTensor<T> ToHost()
{
return ToHost<T>(mMemSize);
}
void SetZero() const
{
if(mpDeviceBuf)
......
......@@ -13,6 +13,7 @@
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"
namespace ck_tile {
......@@ -22,13 +23,44 @@ struct FillUniformDistribution
float a_{-5.f};
float b_{5.f};
std::optional<uint32_t> seed_{11939};
// ATTENTION: threaded does not guarantee the distribution between thread
bool threaded = false;
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
if(threaded)
{
uint32_t num_thread = std::thread::hardware_concurrency();
auto total = static_cast<std::size_t>(std::distance(first, last));
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t iw_begin = it * work_per_thread;
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
auto thread_f = [this, total, iw_begin, iw_end, &first] {
if(iw_begin > total || iw_end > total)
return;
// need to make each thread unique, add an offset to current seed
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
: std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
return ck_tile::type_convert<T>(dis(gen));
});
};
threads[it] = joinable_thread(thread_f);
}
}
else
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
}
}
template <typename ForwardRange>
......@@ -115,13 +147,44 @@ struct FillNormalDistribution
float mean_{0.f};
float variance_{1.f};
std::optional<uint32_t> seed_{11939};
// ATTENTION: threaded does not guarantee the distribution between thread
bool threaded = false;
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
if(threaded)
{
uint32_t num_thread = std::thread::hardware_concurrency();
auto total = static_cast<std::size_t>(std::distance(first, last));
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t iw_begin = it * work_per_thread;
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
auto thread_f = [this, total, iw_begin, iw_end, &first] {
if(iw_begin > total || iw_end > total)
return;
// need to make each thread unique, add an offset to current seed
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
: std::random_device{}());
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
return ck_tile::type_convert<T>(dis(gen));
});
};
threads[it] = joinable_thread(thread_f);
}
}
else
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
std::generate(
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
}
}
template <typename ForwardRange>
......@@ -235,6 +298,44 @@ struct FillMonotonicSeq
}
};
template <typename T, bool IsAscending = true>
struct FillStepRange
{
float start_value_{0};
float end_value_{3};
float step_{1};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::generate(first, last, [=, n = start_value_]() mutable {
auto tmp = n;
n += step_;
if constexpr(IsAscending)
{
if(n > end_value_)
n = start_value_;
}
else
{
if(n < end_value_)
n = start_value_;
}
return type_convert<T>(tmp);
});
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const -> std::void_t<
decltype(std::declval<const FillStepRange&>()(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
};
template <typename T>
struct FillConstant
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -8,12 +8,13 @@
#include <iostream>
#include <iomanip>
#include <numeric>
#include <thread>
#include <utility>
#include <vector>
#include <functional>
#include <fstream>
#include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/ranges.hpp"
namespace ck_tile {
......@@ -213,23 +214,6 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
return HostTensorDescriptor(new_lengths, new_strides);
}
struct joinable_thread : std::thread
{
template <typename... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
{
}
joinable_thread(joinable_thread&&) = default;
joinable_thread& operator=(joinable_thread&&) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <typename F, typename... Xs>
struct ParallelTensorFunctor
{
......@@ -590,7 +574,147 @@ struct HostTensor
size() * FromSize / ToSize};
}
friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
{
os << t.mDesc;
os << "[";
for(typename Data::size_type idx = 0; idx < t.mData.size(); ++idx)
{
if(0 < idx)
{
os << ", ";
}
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
{
os << type_convert<float>(t.mData[idx]) << " #### ";
}
else
{
os << t.mData[idx];
}
}
os << "]";
return os;
}
// read data from a file, as dtype
// the file could dumped from torch as (targeting tensor is t here)
// numpy.savetxt("f.txt", t.view(-1).numpy())
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy()) # from cuda to cpu to save
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d") # save as int
// will output f.txt, each line is a value
// dtype=float or int, internally will cast to real type
void loadtxt(std::string file_name, std::string dtype = "float")
{
std::ifstream file(file_name);
if(file.is_open())
{
std::string line;
index_t cnt = 0;
while(std::getline(file, line))
{
if(cnt >= static_cast<index_t>(mData.size()))
{
throw std::runtime_error(std::string("data read from file:") + file_name +
" is too big");
}
if(dtype == "float")
{
mData[cnt] = type_convert<T>(std::stof(line));
}
else if(dtype == "int" || dtype == "int32")
{
mData[cnt] = type_convert<T>(std::stoi(line));
}
cnt++;
}
file.close();
if(cnt < static_cast<index_t>(mData.size()))
{
std::cerr << "Warning! reading from file:" << file_name
<< ", does not match the size of this tensor" << std::endl;
}
}
else
{
// Print an error message to the standard error
// stream if the file cannot be opened.
throw std::runtime_error(std::string("unable to open file:") + file_name);
}
}
// can save to a txt file and read from torch as:
// torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
void savetxt(std::string file_name, std::string dtype = "float")
{
std::ofstream file(file_name);
if(file.is_open())
{
for(auto& itm : mData)
{
if(dtype == "float")
file << type_convert<float>(itm) << std::endl;
else if(dtype == "int")
file << type_convert<int>(itm) << std::endl;
else
// TODO: we didn't implement operator<< for all custom
// data types, here fall back to float in case compile error
file << type_convert<float>(itm) << std::endl;
}
file.close();
}
else
{
// Print an error message to the standard error
// stream if the file cannot be opened.
throw std::runtime_error(std::string("unable to open file:") + file_name);
}
}
Descriptor mDesc;
Data mData;
};
template <bool is_row_major>
auto host_tensor_descriptor(std::size_t row,
std::size_t col,
std::size_t stride,
bool_constant<is_row_major>)
{
using namespace ck_tile::literals;
if constexpr(is_row_major)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
}
template <bool is_row_major>
auto get_default_stride(std::size_t row,
std::size_t col,
std::size_t stride,
bool_constant<is_row_major>)
{
if(stride == 0)
{
if constexpr(is_row_major)
{
return col;
}
else
{
return row;
}
}
else
return stride;
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <thread>
#include <utility>
namespace ck_tile {
struct joinable_thread : std::thread
{
template <typename... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
{
}
joinable_thread(joinable_thread&&) = default;
joinable_thread& operator=(joinable_thread&&) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename Type>
CK_TILE_HOST void reference_batched_transpose(const HostTensor<Type>& x,
HostTensor<Type>& y,
std::string layout_in = "NCHW",
std::string layout_out = "NHWC")
{
const int N = x.mDesc.get_lengths()[0];
auto f = [&](auto batch) {
if(layout_in == "NCHW" && layout_out == "NHWC")
{
const int C = x.mDesc.get_lengths()[1];
const int H = x.mDesc.get_lengths()[2];
const int W = x.mDesc.get_lengths()[3];
for(int c = 0; c < C; ++c)
{
for(int h = 0; h < H; ++h)
{
for(int w = 0; w < W; ++w)
{
Type v_x = x(batch, c, h, w);
y(batch, h, w, c) = v_x;
}
}
}
}
else if(layout_in == "NHWC" && layout_out == "NCHW")
{
const int H = x.mDesc.get_lengths()[1];
const int W = x.mDesc.get_lengths()[2];
const int C = x.mDesc.get_lengths()[3];
for(int h = 0; h < H; ++h)
{
for(int w = 0; w < W; ++w)
{
for(int c = 0; c < C; ++c)
{
Type v_x = x(batch, h, w, c);
y(batch, c, h, w) = v_x;
}
}
}
}
};
make_ParallelTensorFunctor(f, N)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
// number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
// 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4
// -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
// c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
///
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
template <typename AccDataType, // you only need to explcitly set this one
typename Activation, // ck_tile::element_wise::Gelu
typename ADataType,
typename GDataType,
typename DDataType,
typename ODataType,
typename AScaleDataType,
typename GScaleDataType,
typename DScaleDataType,
typename YSmoothScaleDataType,
typename TopkWeightDataType,
typename IndexDataType>
void reference_fused_moe(
const ck_tile::HostTensor<ADataType>& a_host, // [tokens, hidden_size]
const ck_tile::HostTensor<GDataType>& g_host, // [experts, interme_size_0, hidden_size]
const ck_tile::HostTensor<DDataType>& d_host, // [experts, hidden_size, interme_size_1]
const ck_tile::HostTensor<AScaleDataType>& sa_host, // [tokens, 1],
const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size_0]
const ck_tile::HostTensor<DScaleDataType>& sd_host, // [experts, 1, hidden_size],
const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
ck_tile::HostTensor<ODataType>& o_host, // [tokens, hidden_size]
const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host, // [max_num_tokens_padded]
const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
const ck_tile::HostTensor<IndexDataType>&
sorted_expert_ids_host, // [(max_num_tokens_padded + block_size - 1) / block_size]
const ck_tile::HostTensor<IndexDataType>& num_sorted_tiles_host, // [1]
const ck_tile::HostTensor<IndexDataType>&
token_ids_host, // [tokens, topk] --> ugly!!! remove in the future
ck_tile::index_t block_m,
ck_tile::index_t tokens,
ck_tile::index_t experts,
ck_tile::index_t hidden_size,
ck_tile::index_t intermediate_size, // this size is for gate/up/down
ck_tile::index_t topk,
ck_tile::index_t gate_only)
{
assert(sorted_token_ids_host.get_num_of_dimension() == 1);
assert(sorted_weight_host.get_num_of_dimension() == 1);
assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
assert(num_sorted_tiles_host.get_element_size() == 1);
ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
ck_tile::index_t intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2);
ck_tile::index_t intermediate_size_1 = intermediate_size;
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
// assert();
auto f = [&](auto i_flatten) {
ck_tile::index_t i_tile = i_flatten / block_m;
if(i_tile >= num_sorted_tiles)
return;
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
ck_tile::index_t i_topk = i_token >> 24;
i_token &= 0xffffff;
if(i_token >= tokens)
return;
(void)token_ids_host;
#else
// TODO: better remove this in the future, or modify the token_id value
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
{
if(token_ids_host(token_id_, i_) == expert_id_)
return i_;
}
throw std::runtime_error("not correct token/expert pair\n");
return -1; // TODO: not correct!!
};
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
if(i_token >= tokens)
return;
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
#endif
auto weight = sorted_weight_host.mData[i_flatten];
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
// first gemm
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
{
acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
}
acc_0(0, i_n) = acc;
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
}
ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
if(gate_only)
{
if(intermediate_size_1 != intermediate_size_0)
throw std::runtime_error(
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
", 1:" + std::to_string(intermediate_size_1));
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
Activation{}(y(0, i_n), acc_0(0, i_n));
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
}
}
else
{
if(intermediate_size_1 * 2 != intermediate_size_0)
throw std::runtime_error(
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
", 1:" + std::to_string(intermediate_size_1));
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
AccDataType tmp;
Activation{}(tmp, acc_0(0, i_n));
y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
}
}
// second gemm, loop along gemm-n
ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
{
acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
}
acc_1(0, i_n) = acc * weight; // multiple weight here
}
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n);
}
};
// make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f, max_num_tokens_padded)(1);
// reduce
auto r = [&](auto i_token) {
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
AccDataType acc = type_convert<AccDataType>(0);
for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
{
acc += out_topk_tokens(i_token, i_topk, i_n);
}
o_host(i_token, i_n) = type_convert<ODataType>(acc);
}
};
make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());
(void)num_sorted_tiles_host;
(void)sa_host;
(void)sg_host;
(void)sd_host;
(void)sy_host;
}
} // namespace ck_tile
......@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
ck_tile::type_convert<AccDataType>(B[b_index]);
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = acc;
C[c_index] = ck_tile::type_convert<CDataType>(acc);
}
}
......@@ -97,9 +98,9 @@ template <typename ADataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_gemm_gpu(DeviceMem& a_device,
DeviceMem& b_device,
DeviceMem& c_device,
void reference_gemm_gpu(ADataType* a_ptr,
BDataType* b_ptr,
CDataType* c_ptr,
index_t M,
index_t N,
index_t K,
......@@ -107,78 +108,50 @@ void reference_gemm_gpu(DeviceMem& a_device,
index_t stride_b,
index_t stride_c)
{
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType));
hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType));
hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType));
if(errA != hipSuccess)
{
std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
<< std::endl;
return; // Early exit on error
}
if(errB != hipSuccess)
{
std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
<< std::endl;
return; // Early exit on error
}
if(errC != hipSuccess)
{
std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
<< std::endl;
return; // Early exit on error
}
errA = hipMemcpy(
d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice);
if(errA != hipSuccess)
{
std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
}
errB = hipMemcpy(
d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice);
if(errB != hipSuccess)
{
std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
}
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
errC = hipMemcpy(
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
if(errC != hipSuccess)
{
std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
}
<<<numBlocks, numThreadsPerBlock>>>(
a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
errA = hipFree(d_A);
if(errA != hipSuccess)
{
std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
}
return;
}
errB = hipFree(d_B);
if(errB != hipSuccess)
{
std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_batched_gemm_gpu(ADataType* a_ptr,
BDataType* b_ptr,
CDataType* c_ptr,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c,
index_t batch_stride_A,
index_t batch_stride_B,
index_t batch_stride_C,
index_t batch_count)
{
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
errC = hipFree(d_C);
if(errC != hipSuccess)
for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
{
std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
}
return;
......
......@@ -8,6 +8,9 @@
namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
const HostTensor<WeightType>& weights,
......@@ -20,8 +23,14 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
{
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
const index_t topk = topk_ids.mDesc.get_lengths()[1];
std::vector<std::vector<IndexType>> expert_tokens(experts,
std::vector<IndexType>(unit_size, num_token));
// allocate a temp buffer, and fill the value with [number_token|topk]
std::vector<std::vector<IndexType>> expert_tokens(
experts,
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
std::vector<IndexType>(unit_size, MOE_SORTING_MOCK_ID(num_token, topk)));
#else
std::vector<IndexType>(unit_size, num_token));
#endif
std::vector<std::vector<WeightType>> expert_token_weights(
experts, std::vector<WeightType>(unit_size, 0));
std::vector<IndexType> expert_slices(experts, 1);
......@@ -42,12 +51,19 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
expert_token_weights[e].resize(new_size);
for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
{
expert_tokens[e][i] = num_token;
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
expert_tokens[e][i] = MOE_SORTING_MOCK_ID(num_token, topk);
#else
expert_tokens[e][i] = num_token;
#endif
expert_token_weights[e][i] = 0;
}
}
expert_tokens[e][idx] = t;
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
expert_tokens[e][idx] = MOE_SORTING_MOCK_ID(t, k);
#else
expert_tokens[e][idx] = t;
#endif
expert_token_weights[e][idx] = w;
expert_slice_idxs[e]++;
}
......@@ -75,4 +91,7 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
unit_cnt *= unit_size;
return;
}
#undef MOE_SORTING_MOCK_ID
} // namespace ck_tile
......@@ -16,7 +16,7 @@ namespace ck_tile {
*/
template <typename DataType>
CK_TILE_HOST void
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims)
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
{
const auto x_len = x.mDesc.get_lengths();
const auto y_len = y.mDesc.get_lengths();
......@@ -43,7 +43,7 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
std::vector<size_t> tmp(rank, 0);
for(index_t i = 0; i < rank; i++)
{
tmp[dims[i]] = y_coord[i];
tmp[perm[i]] = y_coord[i];
}
return tmp;
}();
......@@ -54,4 +54,23 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
}
template <typename DataType>
CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
{
auto x_shape = x.get_lengths();
ck_tile::index_t rank = perm.size();
std::vector<ck_tile::index_t> y_shape = [&]() {
std::vector<ck_tile::index_t> tmp(rank, 0);
for(int i = 0; i < static_cast<int>(rank); i++)
{
tmp[i] = x_shape[perm[i]];
}
return tmp;
}();
HostTensor<DataType> y(y_shape);
reference_permute(x, y, perm);
return y;
}
} // namespace ck_tile
......@@ -8,16 +8,40 @@
namespace ck_tile {
// Note: for simplicity, each functor only care about single M
struct reference_rmsnorm2d_default_epilogue
{
template <typename OutDataType, typename AccDataType>
void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
{
const int N = acc.mDesc.get_lengths()[1];
for(int n = 0; n < N; ++n)
{
o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
}
}
template <typename OutDataType, typename AccDataType>
auto operator()(int m, const HostTensor<AccDataType>& acc)
{
HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
operator()(m, o, acc);
return o;
}
};
template <typename XDataType,
typename GammaDataType,
typename ComputeDataType,
typename YDataType,
typename InvRmsDataType>
typename InvRmsDataType,
typename Epilogue = reference_rmsnorm2d_default_epilogue>
void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
const HostTensor<GammaDataType>& gamma_n,
HostTensor<YDataType>& y_m_n,
HostTensor<InvRmsDataType>& invRms_m,
ComputeDataType epsilon)
ComputeDataType epsilon,
Epilogue epilogue_functor = {})
{
auto rmsnorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
......@@ -37,13 +61,15 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);
HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
for(int n = 0; n < N; ++n)
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
auto y = x * divisor * gamma;
y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
acc(m, n) = x * divisor * gamma;
}
epilogue_functor(m, y_m_n, acc);
};
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
......
......@@ -22,7 +22,7 @@ CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor<XDataType>&
// scale = amax / 127 for int8
auto v_scale = type_convert<XDataType>(scale_m(m));
auto v_qx = v_x / v_scale;
qx_m_n(m, n) = saturates<QXDataType>{}(v_qx);
qx_m_n(m, n) = type_convert<QXDataType>(saturates<QXDataType>{}(v_qx));
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
......@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
static constexpr bool kSaveX = Problem::kSaveX;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseMax3 = true; // TODO - Move to trait
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
......@@ -69,9 +70,16 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_max_func = ReduceOp::Max{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
float rtn;
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
auto reduce_max_func = ReduceOp::Max{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
......@@ -116,8 +124,23 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
});
// compute absmax, each-thread->cross-lane->cross-warp
auto absmax = block_reduce2d(
y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
auto absmax = [&]() {
constexpr auto x_size_per_row =
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
x_size_per_row % 2 == 0)
{
return block_reduce2d(y,
reduce_absmax_func.GetIdentityValue<ComputeDataType>(),
reduce_absmax3_func,
sequence<1, 2>{});
}
else
{
return block_reduce2d(
y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
}
}();
block_reduce2d_sync(absmax, reduce_max_func);
block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
......
......@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
static constexpr bool kSaveX = Problem::kSaveX;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseMax3 = true; // TODO - Move to trait
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
......@@ -76,9 +77,16 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_max_func = ReduceOp::Max{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
float rtn;
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
auto reduce_max_func = ReduceOp::Max{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
......@@ -177,7 +185,13 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
y(idx) = type_convert<ComputeDataType>(y_);
});
block_reduce2d(y, absmax, reduce_absmax_func);
constexpr auto x_size_per_row =
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
x_size_per_row % 2 == 0)
block_reduce2d(y, absmax, reduce_absmax3_func, sequence<1, 2>{});
else
block_reduce2d(y, absmax, reduce_absmax_func);
if constexpr(kSaveX)
move_tile_window(x_window, {0, -Block_N});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment