Commit 4fe49693 authored by aska-0096

Merge branch 'kaba' into navi3_rel

parents 809d7dfb cc0ffeb7
@@ -1447,10 +1447,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
1,
0);
v_theother_row = type_convert_sp<SrcData>(temp);
// if (get_thread_local_1d_id() == 0){
// printf("src_offset:%d, dst_offset for this row: %d, dst_offset
// for the other row: %d \n",
// src_offset, dst_offset, dst_offset+DstScalarPerVector);}
if(get_thread_local_1d_id() % 32 < 16)
{
// apply type convert
......
@@ -521,6 +521,114 @@ amd_buffer_store_impl_raw(const typename vector_type<int8_t, N>::type src_thread
dst_wave_addr_offset + sizeof(int32_t) * 12,
static_cast<index_t>(coherence));
}
else if constexpr(is_same<T, uint8_t>::value)
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return llvm_amdgcn_raw_buffer_load_i8x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
#else
int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<uint8x2_t>(tmp);
#endif
}
else if constexpr(N == 4)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
#else
int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<uint8x4_t>(tmp);
#endif
}
else if constexpr(N == 8)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type<uint8_t, 8> tmp;
tmp.AsType<uint8x4_t>()(Number<0>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.AsType<uint8x4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int8_t),
static_cast<index_t>(coherence));
return tmp.AsType<uint8x8_t>()(Number<0>{});
#else
int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<uint8x8_t>(tmp);
#endif
}
else if constexpr(N == 16)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type<uint8_t, 16> tmp;
tmp.AsType<uint8x4_t>()(Number<0>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.AsType<uint8x4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int8_t),
static_cast<index_t>(coherence));
tmp.AsType<uint8x4_t>()(Number<2>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(int8_t),
static_cast<index_t>(coherence));
tmp.AsType<uint8x4_t>()(Number<3>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(int8_t),
static_cast<index_t>(coherence));
return tmp.AsType<uint8x16_t>()(Number<0>{});
#else
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<uint8x16_t>(tmp);
#endif
}
}
}
template <typename T,
......
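The workaround path above replaces the i8x2/i8x4 buffer intrinsics with a single wider dword load that is then bit_cast back into byte lanes. A minimal host-side sketch of why that is equivalent, using std::memcpy in place of the llvm_amdgcn_* intrinsics (illustrative only, not part of CK):

#include <cassert>
#include <cstdint>
#include <cstring>

int main()
{
    // Four consecutive bytes in memory, as the buffer would hold them.
    const uint8_t bytes[4] = {0x11, 0x22, 0x33, 0x44};

    // Workaround path: load the 4 bytes as one 32-bit word ...
    uint32_t word;
    std::memcpy(&word, bytes, sizeof(word));

    // ... then reinterpret the word back into 4 uint8 lanes (the bit_cast step).
    uint8_t lanes[4];
    std::memcpy(lanes, &word, sizeof(lanes));

    // The lanes match element-wise byte loads regardless of the wider load width.
    for(int i = 0; i < 4; ++i)
        assert(lanes[i] == bytes[i]);

    return 0;
}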
@@ -133,6 +133,13 @@ struct scalar_type<int8_t>
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<uint8_t>
{
using type = uint8_t;
static constexpr index_t vector_size = 1;
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct scalar_type<int4_t>
@@ -1037,6 +1044,227 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
// u8
using uint8x2_t = typename vector_type<uint8_t, 2>::type;
using uint8x4_t = typename vector_type<uint8_t, 4>::type;
using uint8x8_t = typename vector_type<uint8_t, 8>::type;
using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type;
// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
// convert bfp16 to fp32
template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
{
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(x) << 16};
return u.fp32;
}
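// A standalone sketch of the widening above (illustrative only, not part of this header):
// shifting the bf16 pattern into the upper 16 bits of an fp32 word reproduces the value,
// e.g. 0x3FC0 -> 0x3FC00000 == 1.5f.
#include <cassert>
#include <cstdint>
#include <cstring>

float bf16_bits_to_float_sketch(uint16_t x)
{
    const uint32_t bits = uint32_t(x) << 16; // widen into the upper half, as the union above does
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main()
{
    assert(bf16_bits_to_float_sketch(0x3FC0) == 1.5f);
    assert(bf16_bits_to_float_sketch(0x0000) == 0.0f);
    return 0;
}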
// Convert X to Y (the specializations below reinterpret the bit pattern rather than convert the value)
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
template <>
inline __host__ __device__ constexpr int type_convert_sp<int, float>(float x)
{
union
{
float fp32;
int int32;
} u = {x};
return u.int32;
}
template <>
inline __host__ __device__ constexpr float type_convert_sp<float, int>(int x)
{
union
{
int int32;
float fp32;
} u = {x};
return u.fp32;
}
template <>
inline __host__ __device__ constexpr int type_convert_sp<int, half_t>(half_t x)
{
union
{
half_t fp16;
int int32;
} u = {x};
return u.int32;
}
template <>
inline __host__ __device__ constexpr half_t type_convert_sp<half_t, int>(int x)
{
union
{
int int32;
half_t fp16;
} u = {x};
return u.fp16;
}
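// Note that type_convert_sp, unlike type_convert, does not change the numeric value: the unions
// above reinterpret the bit pattern (and the 16-bit half_t overloads only populate the low half
// of the 32-bit word). A standalone sketch of the same idea with std::memcpy (illustrative only,
// not part of this header):
#include <cassert>
#include <cstdint>
#include <cstring>

// Equivalent in spirit to type_convert_sp<int, float>: reinterpret, do not round.
int float_bits_as_int_sketch(float x)
{
    int i;
    std::memcpy(&i, &x, sizeof(i));
    return i;
}

int main()
{
    assert(float_bits_as_int_sketch(1.0f) == 0x3F800000); // IEEE-754 encoding of 1.0f
    return 0;
}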
// convert fp32 to bfp16
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
return uint16_t(u.int32 >> 16);
}
// convert bfp16 to fp16 via fp32
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<half_t>(x_fp32);
}
// convert fp16 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int32 via fp32
template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int32_t>(x_fp32);
}
// convert int32 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int8 via fp32
template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int8_t>(x_fp32);
}
// convert int8 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert int8 to fp16 via fp32
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, int8_t>(int8_t x)
{
// TODO: replace it with fast_converter
float x_fp32 = static_cast<float>(x);
return type_convert<half_t>(x_fp32);
}
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool flag0 = ~u.int32 & 0x7f800000;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool flag1 = !flag0 && (u.int32 & 0xffff);
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
return uint16_t(u.int32 >> 16);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return bf16_convert_rtn<bhalf_t>(x_fp32);
}
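// To make the round-to-nearest-even rule above concrete, the following standalone sketch
// (illustrative only, not part of this header) replicates the integer arithmetic on raw fp32
// bits and checks two ties at the halfway point 0x8000: an even bf16 mantissa LSB rounds down,
// an odd one rounds up.
#include <cassert>
#include <cstdint>

uint16_t bf16_rtn_from_f32_bits_sketch(uint32_t bits)
{
    const bool not_inf_nan = (~bits & 0x7f800000u) != 0;
    const bool nan_payload = !not_inf_nan && (bits & 0xffffu) != 0;
    bits += not_inf_nan ? 0x7fffu + ((bits >> 16) & 1u) : 0u; // round to nearest, ties to even
    bits |= nan_payload ? 0x10000u : 0u;                      // preserve signaling NaN
    return uint16_t(bits >> 16);
}

int main()
{
    assert(bf16_rtn_from_f32_bits_sketch(0x3F808000u) == 0x3F80); // tie, even LSB: rounds down
    assert(bf16_rtn_from_f32_bits_sketch(0x3F818000u) == 0x3F82); // tie, odd LSB: rounds up
    return 0;
}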
template <typename T>
struct NumericLimits
......
@@ -133,6 +133,252 @@ struct ReferenceBatchedGemm : public device::BaseOperator
}
};
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct ReferenceBatchedGemm_MQA : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_g0_g1_m_k,
const Tensor<BDataType>& b_g0_1_k_n,
Tensor<CDataType>& c_g0_g1_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_g0_g1_m_k_{a_g0_g1_m_k},
b_g0_1_k_n_{b_g0_1_k_n},
c_g0_g1_m_n_{c_g0_g1_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_g0_g1_m_k_;
const Tensor<BDataType>& b_g0_1_k_n_;
Tensor<CDataType>& c_g0_g1_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceBatchedGemm_MQA::Argument;
float Run(const Argument& arg)
{
auto f_g0g1mk_g01kn_g0g1mn = [&](auto g0, auto g1, auto m, auto n) {
const int K = arg.a_g0_g1_m_k_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a;
BDataType v_b;
arg.a_element_op_(v_a, arg.a_g0_g1_m_k_(g0, g1, m, k));
arg.b_element_op_(v_b, arg.b_g0_1_k_n_(g0, 0, k, n));
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
arg.c_g0_g1_m_n_(g0, g1, m, n) = ck::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(f_g0g1mk_g01kn_g0g1mn,
arg.c_g0_g1_m_n_.mDesc.GetLengths()[0],
arg.c_g0_g1_m_n_.mDesc.GetLengths()[1],
arg.c_g0_g1_m_n_.mDesc.GetLengths()[2],
arg.c_g0_g1_m_n_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_g0_g1_m_k,
const Tensor<BDataType>& b_g0_1_k_n,
Tensor<CDataType>& c_g0_g1_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{
a_g0_g1_m_k, b_g0_1_k_n, c_g0_g1_m_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceBatchedGemm_MQA"
<< std::endl;
// clang-format on
return str.str();
}
};
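// In the MQA reference above, every query head g1 reads the same single KV head: B is always
// indexed as b_g0_1_k_n_(g0, 0, k, n). A standalone sketch of that broadcast on flat row-major
// buffers (layouts and names here are illustrative, not the CK Tensor API):
#include <cassert>
#include <vector>

// a: [G1, M, K], b: [1, K, N] (one shared KV head), c: [G1, M, N]
void mqa_reference_sketch(const std::vector<float>& a, const std::vector<float>& b,
                          std::vector<float>& c, int G1, int M, int N, int K)
{
    for(int g1 = 0; g1 < G1; ++g1)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += a[(g1 * M + m) * K + k] * b[k * N + n]; // b carries no g1 index
                c[(g1 * M + m) * N + n] = acc;
            }
}

int main()
{
    const int G1 = 2, M = 1, N = 1, K = 2;
    std::vector<float> a{1.f, 2.f, 3.f, 4.f}; // query head 0: [1, 2], query head 1: [3, 4]
    std::vector<float> b{10.f, 20.f};         // the single shared KV head
    std::vector<float> c(G1 * M * N, 0.f);
    mqa_reference_sketch(a, b, c, G1, M, N, K);
    assert(c[0] == 50.f);  // 1*10 + 2*20
    assert(c[1] == 110.f); // 3*10 + 4*20
    return 0;
}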
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
ck::index_t QueryGroupNumber>
struct ReferenceBatchedGemm_GQA : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_g0_g1_m_k,
const Tensor<BDataType>& b_g0_gq_k_n,
Tensor<CDataType>& c_g0_g1_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_g0_g1_m_k_{a_g0_g1_m_k},
b_g0_gq_k_n_{b_g0_gq_k_n},
c_g0_g1_m_n_{c_g0_g1_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_g0_g1_m_k_;
const Tensor<BDataType>& b_g0_gq_k_n_;
Tensor<CDataType>& c_g0_g1_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceBatchedGemm_GQA::Argument;
float Run(const Argument& arg)
{
auto f_g0g1mk_g0gqkn_g0g1mn = [&](auto g0, auto g1, auto m, auto n) {
const int G1 = arg.a_g0_g1_m_k_.mDesc.GetLengths()[1];
const int K = arg.a_g0_g1_m_k_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a;
BDataType v_b;
arg.a_element_op_(v_a, arg.a_g0_g1_m_k_(g0, g1, m, k));
arg.b_element_op_(v_b, arg.b_g0_gq_k_n_(g0, g1 * QueryGroupNumber / G1, k, n));
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
arg.c_g0_g1_m_n_(g0, g1, m, n) = ck::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(f_g0g1mk_g0gqkn_g0g1mn,
arg.c_g0_g1_m_n_.mDesc.GetLengths()[0],
arg.c_g0_g1_m_n_.mDesc.GetLengths()[1],
arg.c_g0_g1_m_n_.mDesc.GetLengths()[2],
arg.c_g0_g1_m_n_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_g0_g1_m_k,
const Tensor<BDataType>& b_g0_gq_k_n,
Tensor<CDataType>& c_g0_g1_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{
a_g0_g1_m_k, b_g0_gq_k_n, c_g0_g1_m_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceBatchedGemm_GQA"
<< std::endl;
// clang-format on
return str.str();
}
};
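// The GQA reference generalizes MQA: query head g1 maps to KV head g1 * QueryGroupNumber / G1,
// so (assuming G1 is a multiple of QueryGroupNumber) the G1 query heads appear to be split into
// QueryGroupNumber groups that each share one KV head. A standalone sketch of that index
// mapping (illustrative only):
#include <cassert>

int gqa_kv_head_index_sketch(int g1, int G1, int query_group_number)
{
    return g1 * query_group_number / G1; // integer division groups consecutive query heads
}

int main()
{
    // G1 = 8 query heads, QueryGroupNumber = 2 KV heads: heads 0-3 share KV head 0,
    // heads 4-7 share KV head 1.
    for(int g1 = 0; g1 < 8; ++g1)
        assert(gqa_kv_head_index_sketch(g1, 8, 2) == (g1 < 4 ? 0 : 1));
    return 0;
}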
} // namespace host
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename ScaleDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct ReferencefpAintBGemm : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const Tensor<ScaleDataType>& scale_k_n,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_m_k_{a_m_k},
b_k_n_{b_k_n},
scale_k_n_{scale_k_n},
c_m_n_{c_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_m_k_;
const Tensor<BDataType>& b_k_n_;
const Tensor<ScaleDataType>& scale_k_n_;
Tensor<CDataType>& c_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferencefpAintBGemm::Argument;
float Run(const Argument& arg)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a;
BDataType v_b;
ScaleDataType v_scale;
ADataType v_converted_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
}
// same for B matrix
if constexpr(is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
}
// same for scale matrix
if constexpr(is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_scale,
arg.scale_k_n_(k, n));
}
else
{
arg.b_element_op_(v_scale, arg.scale_k_n_(k, n));
}
v_converted_b = type_convert<ADataType>(v_b) * v_scale;
v_acc += ck::type_convert<AccDataType>(v_a) *
ck::type_convert<AccDataType>(v_converted_b);
}
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(
f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const Tensor<ScaleDataType>& scale_k_n,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{a_m_k, b_k_n, scale_k_n, c_m_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceGemm"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
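The fpA-intB reference above dequantizes each B element with its per-element scale before accumulating, i.e. C(m, n) = sum_k A(m, k) * (convert(B(k, n)) * scale(k, n)). A minimal standalone sketch of that math with float A, int8 B and float scales (flat row-major layouts; names are illustrative, not the CK API):

#include <cassert>
#include <cstdint>
#include <vector>

// Dequantize B with a per-element scale, then accumulate in float.
void fpA_intB_reference_sketch(const std::vector<float>& a,     // [M, K]
                               const std::vector<int8_t>& b,    // [K, N]
                               const std::vector<float>& scale, // [K, N]
                               std::vector<float>& c,           // [M, N]
                               int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[m * K + k] * (float(b[k * N + n]) * scale[k * N + n]);
            c[m * N + n] = acc;
        }
}

int main()
{
    const int M = 1, N = 1, K = 2;
    std::vector<float> a{2.f, 3.f};
    std::vector<int8_t> b{4, -2};
    std::vector<float> scale{0.5f, 1.f};
    std::vector<float> c(M * N, 0.f);
    fpA_intB_reference_sketch(a, b, scale, c, M, N, K);
    assert(c[0] == -2.f); // 2*(4*0.5) + 3*(-2*1) = 4 - 6
    return 0;
}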