Unverified commit 8b49f207 authored by Max Podkorytov, committed by GitHub

Merge branch 'develop' into fa-h512

parents 0d59f474 a6b761c3
@@ -18,6 +18,20 @@
 #define CK_USE_OCP_FP8 0
 #endif
+
+namespace {
+// https://en.cppreference.com/w/cpp/types/conditional
+template <bool B, class T, class F>
+struct conditional
+{
+    using type = T;
+};
+
+template <class T, class F>
+struct conditional<false, T, F>
+{
+    using type = F;
+};
+} // namespace
+
 namespace ck {
 
 using f8_fnuz_t = _BitInt(8);
@@ -191,11 +205,10 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
         }
     }
 
-    typename __hip_internal::conditional<
+    typename conditional<
         sizeof(T) == 2,
         unsigned short int,
-        typename __hip_internal::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::
-            type>::type retval;
+        typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type retval;
 
     if constexpr(we == 5 && is_half && !is_fnuz)
     {
@@ -538,11 +551,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
     constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
 
-    using T_bitwise = typename __hip_internal::conditional<
+    using T_bitwise = typename conditional<
         sizeof(T) == 2,
         unsigned short int,
-        typename __hip_internal::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::
-            type>::type;
+        typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
 
     T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
     unsigned long long x{x_bitwise};
...
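The point of this hunk is the file-local `conditional` trait, which removes the dependency on `__hip_internal::conditional`. A minimal sketch of how such a trait selects a same-width unsigned integer type, mirroring its use in the cast helpers above (the alias name `bitwise_of_t` is illustrative, not taken from the diff):

```cpp
#include <type_traits>

namespace {
// Same shape as the trait added in the diff.
template <bool B, class T, class F>
struct conditional { using type = T; };

template <class T, class F>
struct conditional<false, T, F> { using type = F; };
} // namespace

// Hypothetical helper: pick an unsigned integer the same size as T,
// the way cast_from_f8/cast_to_f8 declare their bitwise scratch type.
template <typename T>
using bitwise_of_t = typename conditional<
    sizeof(T) == 2,
    unsigned short int,
    typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;

static_assert(std::is_same_v<bitwise_of_t<float>, unsigned int>);
static_assert(std::is_same_v<bitwise_of_t<double>, unsigned long long>);
```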
@@ -4,13 +4,34 @@
 #ifndef CK_AMD_INLINE_ASM_HPP
 #define CK_AMD_INLINE_ASM_HPP
 
+#include "data_type.hpp"
 #include "c_style_pointer_cast.hpp"
-#include "data_type.hpp"
 
 // TODO: deprecate all amd_assembly_outer_product_xxx
 
 namespace ck {
 
+inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
+{
+    int c;
+    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(d));
+    return c;
+}
+
+inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
+{
+    half2_t d;
+    asm volatile("v_pk_fma_f16 %0, %1, %2, %3" : "=v"(d) : "v"(a), "v"(b), "v"(c));
+    return d;
+}
+
+inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
+{
+    half2_t c;
+    asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
+    return c;
+}
+
 // c0 += inner_product(a, b0)
 // c1 += inner_product(a, b1)
 __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
...
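Each new wrapper emits a single packed/VALU instruction. A hedged device-side usage sketch (the kernel name and include paths are illustrative, not code from this PR):

```cpp
#include <hip/hip_runtime.h>
#include "data_type.hpp"      // path assumed, as in the diff above
#include "amd_inline_asm.hpp" // path assumed

// Sketch: two fp16 lanes per instruction, c[i] = a[i] * b[i] + c[i].
__global__ void packed_fma_example(const ck::half2_t* a,
                                   const ck::half2_t* b,
                                   ck::half2_t* c,
                                   int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
    {
        c[i] = ck::amd_assembly_pk_fma_f16(a[i], b[i], c[i]);
    }
}
```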
@@ -12,6 +12,17 @@ using bhalf_t = ushort;
 using half_t = _Float16;
 using int4_t = _BitInt(4);
 
+// custom data type - pack int4 data
+struct pk_i4_t
+{
+    using type = int8_t;
+    type data;
+    __host__ __device__ constexpr pk_i4_t() : data{type{}} {}
+    __host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
+    __host__ __device__ constexpr operator float() const { return static_cast<int8_t>(data); }
+};
+
 inline constexpr auto next_pow2(uint32_t x)
 {
     // Precondition: x > 1.
@@ -165,6 +176,13 @@ struct scalar_type<int4_t>
 };
 #endif
 
+template <>
+struct scalar_type<pk_i4_t>
+{
+    using type = pk_i4_t;
+    static constexpr index_t vector_size = 1;
+};
+
 template <>
 struct scalar_type<f8_fnuz_t>
 {
@@ -1044,6 +1062,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
     using type = bf8_ocp_t::data_type;
 };
 
+template <>
+struct nnvb_data_t_selector<pk_i4_t>
+{
+    using type = pk_i4_t::type;
+};
+
 template <typename T, index_t N>
 struct non_native_vector_base<
     T,
@@ -1163,6 +1187,14 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
     static constexpr index_t vector_size = N;
 };
 
+template <index_t N>
+struct scalar_type<non_native_vector_base<pk_i4_t, N>>
+{
+    using type = typename non_native_vector_base<pk_i4_t, N>::data_t;
+    static constexpr index_t vector_size = N;
+};
+
 // non-native vector_type implementation
 template <typename T>
 struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
@@ -1871,6 +1903,11 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
 using uint8x32_t = typename vector_type<uint8_t, 32>::type;
 using uint8x64_t = typename vector_type<uint8_t, 64>::type;
 
+// pack int4
+using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
+using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
+using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type;
+
 template <typename T>
 struct NumericLimits
 {
...
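`pk_i4_t` stores two 4-bit values in one `int8_t`. A self-contained host-side sketch of packing and unpacking a byte in the low-nibble-first convention used by the `type_convert` specialization later in this diff (the helper names are illustrative, not part of the PR):

```cpp
#include <cstdint>

// Minimal stand-in for the struct added above, for illustration only.
struct pk_i4_t
{
    using type = int8_t;
    type data;
};

// Hypothetical helpers: pack two 4-bit values into one byte and read them back.
inline pk_i4_t pack_int4x2(uint8_t lo, uint8_t hi)
{
    pk_i4_t p;
    p.data = static_cast<int8_t>((hi << 4) | (lo & 0x0f));
    return p;
}

inline void unpack_int4x2(pk_i4_t p, uint8_t& lo, uint8_t& hi)
{
    const uint8_t u = static_cast<uint8_t>(p.data);
    lo = u & 0x0f;
    hi = (u & 0xf0) >> 4;
}

int main()
{
    uint8_t lo = 0, hi = 0;
    unpack_int4x2(pack_int4x2(3, 7), lo, hi); // lo == 3, hi == 7
    return (lo == 3 && hi == 7) ? 0 : 1;
}
```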
@@ -54,7 +54,8 @@ struct DynamicBuffer
     template <typename X,
               typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
-                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
+                                         typename scalar_type<remove_cvref_t<T>>::type>::value ||
+                                     !is_native_type<X>(),
                                  bool>::type = false>
     __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
     {
@@ -195,7 +196,8 @@ struct DynamicBuffer
     template <typename X,
               typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
-                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
+                                         typename scalar_type<remove_cvref_t<T>>::type>::value ||
+                                     !is_native_type<X>(),
                                  bool>::type = false>
     __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
     {
...
@@ -611,7 +611,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x)
 template <>
 inline __device__ half_t neg<half_t>(half_t x)
 {
-    return __hneg(x);
+    return __hneg(static_cast<__half>(x));
 };
 
 template <typename T>
...
@@ -116,7 +116,8 @@ struct StaticBufferTupleOfVector
     // i is offset of S, not X. i should be aligned to X
     template <typename X,
               index_t I,
-              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
+              typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
+                                 bool>::type = false>
     __host__ __device__ constexpr auto GetAsType(Number<I> i) const
     {
         constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
@@ -134,7 +135,8 @@ struct StaticBufferTupleOfVector
     // i is offset of S, not X. i should be aligned to X
     template <typename X,
               index_t I,
-              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
+              typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
+                                 bool>::type = false>
     __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
     {
         constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -465,6 +465,19 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
 #endif
 }
 
+template <>
+inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
+{
+    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
+    uint8_t x_l  = (x_u8 & 0x0f) >> 0;
+    uint8_t x_h  = (x_u8 & 0xf0) >> 4;
+
+    auto l_f32 = ck::type_convert<float>(x_l);
+    auto h_f32 = ck::type_convert<float>(x_h);
+
+    return {l_f32, h_f32};
+}
+
 template <>
 inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
 {
...
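A hedged usage sketch of the new conversion; the include path is an assumption, and the packed value is chosen so both nibbles decode to small non-negative floats:

```cpp
// Header name is an assumption; the specialization above lives in a CK utility header.
#include "ck/utility/type_convert.hpp"

__host__ __device__ inline ck::float2_t unpack_to_floats()
{
    // 0x21 packs 1 in the low nibble and 2 in the high nibble,
    // so the result is {1.0f, 2.0f} (low lane first).
    ck::pk_i4_t packed{static_cast<int8_t>(0x21)};
    return ck::type_convert<ck::float2_t>(packed);
}
```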
@@ -45,5 +45,8 @@ our implementation of different device operators.
 **[ops/epilogue]**
 epilogue part of our kernel. We may extend this epilogue to let users build their own customized epilogues.
 
+**[ref]**
+reference implementations on CPU or GPU. This folder is supposed to include specific headers on demand.
+
 ## examples
 currently we put all ck_tile related examples under the [/example/ck_tile](/example/ck_tile/) folder. Please check each example's subfolder.
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -54,6 +54,7 @@
 #include "ck_tile/core/tensor/tile_window_linear.hpp"
 #include "ck_tile/core/tensor/tile_window_utils.hpp"
 #include "ck_tile/core/tensor/update_tile.hpp"
+#include "ck_tile/core/utility/amd_address_space.hpp"
 #include "ck_tile/core/utility/bit_cast.hpp"
 #include "ck_tile/core/utility/functional.hpp"
 #include "ck_tile/core/utility/functional_with_tuple.hpp"
...
@@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffer_
     static_assert(
         (std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-            (std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-            (std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+            (std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (std::is_same<T, int32_t>::value &&
              (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
             (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
...
@@ -30,7 +30,7 @@ struct meta_data_buffer
     {
         constexpr index_t size = sizeof(T);
-        auto tmp = bit_cast<array<std::byte, size>>(data);
+        auto tmp = ck_tile::bit_cast<array<std::byte, size>>(data);
         for(int i = 0; i < size; i++)
         {
@@ -66,7 +66,7 @@ struct meta_data_buffer
             pos++;
         }
-        data = bit_cast<T>(tmp);
+        data = ck_tile::bit_cast<T>(tmp);
     }
 
     return data;
@@ -86,7 +86,7 @@ struct meta_data_buffer
             pos++;
         }
-        auto data = bit_cast<T>(tmp);
+        auto data = ck_tile::bit_cast<T>(tmp);
 
         return data;
     }
...
@@ -29,6 +29,7 @@ struct static_distributed_tensor
         remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
     static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
+    static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");
 
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
     {
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
...
@@ -15,11 +15,14 @@
 namespace ck_tile {
 
 /*
- * a host side utility, arg parser for
- *   -[key0]=[value0] -[key1]=[value1] ...
+ * a host-side utility, arg parser for either
+ *   -[key0] = [value0, value1, value2]
+ * or
+ *   -[key0]=[value0] -[key1]=[value1] ...
 */
 class ArgParser
 {
     public:
     class Arg
     {
@@ -187,6 +190,45 @@ class ArgParser
         return value;
     }
 
+    std::vector<std::string> get_string_vec(const std::string& name,
+                                            const std::string& delimiter = ",") const
+    {
+        if(get_str(name).empty())
+        {
+            return {};
+        }
+        std::string s = get_str(name);
+        std::vector<std::string> tokens;
+        size_t pos = 0;
+        std::string token;
+        while((pos = s.find(delimiter)) != std::string::npos)
+        {
+            token = s.substr(0, pos);
+            tokens.push_back(token);
+            s.erase(0, pos + delimiter.length());
+        }
+        tokens.push_back(s);
+
+        return tokens;
+    }
+
+    std::vector<int> get_int_vec(const std::string& name, const std::string& delimiter = ",") const
+    {
+        if(get_str(name).empty())
+        {
+            return {};
+        }
+        const std::vector<std::string> args = get_string_vec(name, delimiter);
+        std::vector<int> tokens;
+        tokens.reserve(static_cast<int>(args.size()));
+
+        for(const std::string& token : args)
+        {
+            int value = atoi(token.c_str());
+            tokens.push_back(value);
+        }
+
+        return tokens;
+    }
+
     private:
     std::unordered_map<std::string, Arg> input_map;
     std::vector<std::string> keys;
...
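A hedged usage sketch of the new vector getters; the `insert`/`parse` calls follow the existing `ArgParser` usage in ck_tile examples, and the option name and include path are illustrative:

```cpp
#include <string>
#include <vector>
#include "ck_tile/host/arg_parser.hpp" // path is an assumption

int main(int argc, char* argv[])
{
    ck_tile::ArgParser parser;
    // Register a comma-separated option with a default value.
    parser.insert("lens", "32,64,128", "comma separated sizes");
    parser.parse(argc, argv);

    // "-lens=8,16" on the command line yields {8, 16}; the default yields {32, 64, 128}.
    std::vector<int> lens            = parser.get_int_vec("lens");
    std::vector<std::string> raw     = parser.get_string_vec("lens");

    return lens.empty() ? 1 : 0;
}
```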
@@ -97,9 +97,9 @@ template <typename ADataType,
           typename LayoutA,
           typename LayoutB,
           typename LayoutC>
-void reference_gemm_gpu(DeviceMem& a_device,
-                        DeviceMem& b_device,
-                        DeviceMem& c_device,
+void reference_gemm_gpu(ADataType* a_ptr,
+                        BDataType* b_ptr,
+                        CDataType* c_ptr,
                         index_t M,
                         index_t N,
                         index_t K,
@@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device,
                         index_t stride_b,
                         index_t stride_c)
 {
-    ADataType* d_A;
-    BDataType* d_B;
-    CDataType* d_C;
-    hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType));
-    hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType));
-    hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType));
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
-                  << std::endl;
-        return; // Early exit on error
-    }
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
-                  << std::endl;
-        return; // Early exit on error
-    }
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
-                  << std::endl;
-        return; // Early exit on error
-    }
-    errA = hipMemcpy(
-        d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice);
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
-    }
-    errB = hipMemcpy(
-        d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice);
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
-    }
     int totalElements      = M * N;
     int numThreadsPerBlock = 256; // Common choice for threads per block
     int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
 
     naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
-        <<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
-    errC = hipMemcpy(
-        c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
-    }
-    errA = hipFree(d_A);
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
-    }
-    errB = hipFree(d_B);
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
-    }
-    errC = hipFree(d_C);
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
-    }
+        <<<numBlocks, numThreadsPerBlock>>>(
+            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
 
     return;
 }
@@ -191,9 +125,9 @@ template <typename ADataType,
           typename LayoutA,
           typename LayoutB,
           typename LayoutC>
-void reference_batched_gemm_gpu(DeviceMem& a_device,
-                                DeviceMem& b_device,
-                                DeviceMem& c_device,
+void reference_batched_gemm_gpu(ADataType* a_ptr,
+                                BDataType* b_ptr,
+                                CDataType* c_ptr,
                                 index_t M,
                                 index_t N,
                                 index_t K,
@@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device,
                                 index_t batch_stride_C,
                                 index_t batch_count)
 {
-    ADataType* d_A;
-    BDataType* d_B;
-    CDataType* d_C;
-    hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType));
-    hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType));
-    hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType));
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
-                  << std::endl;
-        return; // Early exit on error
-    }
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
-                  << std::endl;
-        return; // Early exit on error
-    }
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
-                  << std::endl;
-        return; // Early exit on error
-    }
-    errA = hipMemcpy(d_A,
-                     a_device.GetDeviceBuffer(),
-                     batch_count * M * K * sizeof(ADataType),
-                     hipMemcpyHostToDevice);
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
-    }
-    errB = hipMemcpy(d_B,
-                     b_device.GetDeviceBuffer(),
-                     batch_count * N * K * sizeof(BDataType),
-                     hipMemcpyHostToDevice);
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
-    }
     int totalElements      = M * N;
     int numThreadsPerBlock = 256; // Common choice for threads per block
     int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
 
     for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
     {
-        ADataType* d_ATemp = d_A + batch_id * batch_stride_A;
-        BDataType* d_BTemp = d_B + batch_id * batch_stride_B;
-        CDataType* d_CTemp = d_C + batch_id * batch_stride_C;
+        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
+        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
+        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
         naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
             <<<numBlocks, numThreadsPerBlock>>>(
                 d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
     }
-    errC = hipMemcpy(c_device.GetDeviceBuffer(),
-                     d_C,
-                     batch_count * M * N * sizeof(CDataType),
-                     hipMemcpyDeviceToHost);
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
-    }
-    errA = hipFree(d_A);
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
-    }
-    errB = hipFree(d_B);
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
-    }
-    errC = hipFree(d_C);
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
-    }
 
     return;
 }
 
 } // namespace ck_tile
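With `DeviceMem&` replaced by raw device pointers, the caller now owns allocation, transfer, and cleanup. A hedged caller sketch (data types, layout tags, and the header providing `reference_gemm_gpu` are assumptions, not shown in this diff):

```cpp
#include <hip/hip_runtime.h>
// Assumes a header declaring ck_tile::reference_gemm_gpu is already included.

void run_reference_gemm(ck_tile::index_t M,
                        ck_tile::index_t N,
                        ck_tile::index_t K,
                        ck_tile::index_t stride_a,
                        ck_tile::index_t stride_b,
                        ck_tile::index_t stride_c)
{
    float *a_dev = nullptr, *b_dev = nullptr, *c_dev = nullptr;
    if(hipMalloc(&a_dev, M * K * sizeof(float)) != hipSuccess ||
       hipMalloc(&b_dev, N * K * sizeof(float)) != hipSuccess ||
       hipMalloc(&c_dev, M * N * sizeof(float)) != hipSuccess)
    {
        return; // allocation failed
    }

    // ... hipMemcpy host A/B into a_dev/b_dev with hipMemcpyHostToDevice ...

    // Layout tags are illustrative.
    ck_tile::reference_gemm_gpu<float, float, float, float,
                                ck_tile::tensor_layout::gemm::RowMajor,
                                ck_tile::tensor_layout::gemm::ColumnMajor,
                                ck_tile::tensor_layout::gemm::RowMajor>(
        a_dev, b_dev, c_dev, M, N, K, stride_a, stride_b, stride_c);

    // ... hipMemcpy c_dev back to the host with hipMemcpyDeviceToHost ...

    (void)hipFree(a_dev);
    (void)hipFree(b_dev);
    (void)hipFree(c_dev);
}
```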
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -56,6 +56,13 @@ struct CShuffleEpilogue
     // No additional shared memory needed
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
 
+    CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed()
+    {
+        // TODO: CShuffle does not currently allow a vector store after the permute.
+        // Once that is fixed, this function should return true.
+        return false;
+    }
+
     template <typename OAccTile>
     CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
     {
@@ -111,7 +118,9 @@ struct CShuffleEpilogue
         }
     }
 
-    template <typename ODramWindowTmp, typename OAccTile>
+    template <typename ODramWindowTmp,
+              typename OAccTile,
+              memory_operation_enum out_memory_data_op = memory_operation_enum::set>
     CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
     {
         const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
@@ -158,12 +167,26 @@ struct CShuffleEpilogue
         // Store the tile data to the permuted location
         if constexpr(kPadM || kPadN)
         {
-            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            if constexpr(out_memory_data_op == memory_operation_enum::set)
+            {
+                store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
+            else
+            {
+                update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
             buffer_store_fence();
         }
         else
         {
-            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            if constexpr(out_memory_data_op == memory_operation_enum::set)
+            {
+                store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
+            else
+            {
+                update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
        }
     }
 };
...
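The new `out_memory_data_op` template parameter lets the same epilogue either overwrite the output tile (`set`, the default) or accumulate into it via `update_tile`. A hedged call sketch from inside a kernel (variable names and the `atomic_add` enumerator are assumptions, not shown in this diff):

```cpp
// Inside device code that already built `epilogue`, `o_dram_window`, and `o_acc_tile`.

// Default template argument keeps the old overwrite behaviour:
epilogue(o_dram_window, o_acc_tile);

// Hedged: accumulate into the existing output instead (e.g. for split-k),
// assuming ck_tile::memory_operation_enum::atomic_add is available.
epilogue.template operator()<decltype(o_dram_window),
                             decltype(o_acc_tile),
                             ck_tile::memory_operation_enum::atomic_add>(o_dram_window,
                                                                         o_acc_tile);
```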