Commit d783a8cf authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Merge branch 'develop' into feature/use-larger-tile-size-for-chunk-prefill

parents 1b130866 4cb3d7d7
...@@ -761,7 +761,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -761,7 +761,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
float time{0.f}; float time{0.f};
hip_check_error( hip_check_error(
hipMemcpyWithStream(dev_gemm_kargs, hipMemcpyAsync(dev_gemm_kargs,
arg.gemm_kernel_args_.data(), arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
......
...@@ -940,7 +940,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop ...@@ -940,7 +940,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const void* p_host_kernel_args) const const void* p_host_kernel_args) const
{ {
arg.p_dev_gemm_args_ = p_dev_kernel_args; arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpy(p_dev_kernel_args, hip_check_error(hipMemcpyAsync(p_dev_kernel_args,
p_host_kernel_args, p_host_kernel_args,
GetDeviceKernelArgSize(&arg), GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice));
......
...@@ -557,10 +557,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -557,10 +557,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
} }
} }
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_, hipGetErrorString(
hipMemcpyAsync(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
stream_config.stream_id_)); stream_config.stream_id_));
......
...@@ -421,7 +421,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -421,7 +421,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
} }
hip_check_error( hip_check_error(
hipMemcpyWithStream(arg.p_workspace_, hipMemcpyAsync(arg.p_workspace_,
arg.gemm_kernel_args_.data(), arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
......
...@@ -38,8 +38,7 @@ __global__ void ...@@ -38,8 +38,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1))) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
......
...@@ -549,8 +549,10 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -549,8 +549,10 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, f8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, bf8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, fp8_storage_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented"); "wrong! not implemented");
...@@ -843,8 +845,8 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -843,8 +845,8 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#else #else
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0)};
return src_thread_element_valid ? tmp : vector_t(0); return src_thread_element_valid ? tmp : vector_t(0);
#endif #endif
} }
...@@ -873,8 +875,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, ...@@ -873,8 +875,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0)};
return src_thread_element_valid ? tmp : vector_t(customized_value); return src_thread_element_valid ? tmp : vector_t(customized_value);
} }
......
This diff is collapsed.
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#pragma once #pragma once
namespace ck { namespace ck {
// Define the common macro for gfx94x models // Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__ #define __gfx94__
#endif #endif
......
This diff is collapsed.
...@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x) ...@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __host__ bool isnan(f8_t x) { return (x & 0x80); }; static inline __host__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x) static inline __host__ bool isnan(int4_t x)
...@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x) ...@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __device__ bool isnan(f8_t x) { return (x & 0x80); }; static inline __device__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
static inline __device__ half_t sqrt(half_t x) static inline __device__ half_t sqrt(half_t x)
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/ck.hpp"
namespace ck { namespace ck {
// Pseudo random number generator // Pseudo random number generator
...@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = ...@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
} }
// version for fp16 // version for fp16
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false> template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{ {
uint16_t x = *(reinterpret_cast<uint16_t*>(&val)); uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
...@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = ...@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
} }
// return 0 if data is not fp16 or fp32 // return 0 if data is not fp16 or fp32
template <typename T, template <
typename T,
uint32_t seed_t, uint32_t seed_t,
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false> std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t) __host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{ {
std::ignore = id; std::ignore = id;
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/utility/array.hpp" #include "ck/utility/array.hpp"
namespace ck { namespace ck {
// Define the common macro for gfx94x models // Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__ #define __gfx94__
#endif #endif
...@@ -100,6 +100,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_ ...@@ -100,6 +100,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32); return type_convert<bhalf_t>(x_fp32);
} }
// convert int to f8_ocp_t: the value is converted to f8_ocp_t's underlying
// data_type first and the result is wrapped in an f8_ocp_t
template <>
inline __host__ __device__ constexpr f8_ocp_t type_convert<f8_ocp_t, int>(int x)
{
    using raw_type = f8_ocp_t::data_type;
    return f8_ocp_t{type_convert<raw_type>(x)};
}
// convert int to bf8_ocp_t: the value is converted to bf8_ocp_t's underlying
// data_type first and the result is wrapped in a bf8_ocp_t
template <>
inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int x)
{
    using raw_type = bf8_ocp_t::data_type;
    return bf8_ocp_t{type_convert<raw_type>(x)};
}
// Convert X to Y // Convert X to Y
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x) __host__ __device__ constexpr Y type_convert_sp(X x)
...@@ -163,7 +175,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); ...@@ -163,7 +175,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding // convert fp32 to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
...@@ -189,33 +201,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) ...@@ -189,33 +201,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils:: return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
rng); x, rng);
#endif #endif
} }
// convert fp16 to fp8 with stochastic rounding // convert fp16 to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// convert to float and use native converion // convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x)); return f8_convert_sr<f8_fnuz_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils:: return utils::cast_to_f8<half_t,
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( f8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp32 to bf8 with stochastic rounding // convert fp32 to bf8 with stochastic rounding
template <> template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x) inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
...@@ -240,28 +254,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x) ...@@ -240,28 +254,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils:: return utils::cast_to_f8<float,
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( bf8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp16 to bf8 with stochastic rounding // convert fp16 to bf8 with stochastic rounding
template <> template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x) inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// convert to float and use native converion // convert to float and use native converion
return f8_convert_sr<bf8_t>(type_convert<float>(x)); return f8_convert_sr<bf8_fnuz_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils:: return utils::cast_to_f8<half_t,
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( bf8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
...@@ -271,7 +289,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x); ...@@ -271,7 +289,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even // convert fp32 to fp8 with rounding to nearest even
template <> template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x) inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
union union
...@@ -296,32 +314,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x) ...@@ -296,32 +314,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils:: return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
rng); x, rng);
#endif #endif
} }
// convert fp16 to fp8 with rounding to nearest even // convert fp16 to fp8 with rounding to nearest even
template <> template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x) inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, half_t>(half_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// convert to float and use native converion // convert to float and use native converion
return f8_convert_rne<f8_t>(type_convert<float>(x)); return f8_convert_rne<f8_fnuz_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils:: return utils::cast_to_f8<half_t,
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( f8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp32 to bf8 with rounding to nearest even // convert fp32 to bf8 with rounding to nearest even
template <> template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x) inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, float>(float x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
union union
...@@ -345,44 +365,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x) ...@@ -345,44 +365,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils:: return utils::cast_to_f8<float,
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( bf8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp16 to bf8 with rounding to nearest even // convert fp16 to bf8 with rounding to nearest even
template <> template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x) inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, half_t>(half_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// convert to float and use native converion // convert to float and use native converion
return f8_convert_rne<bf8_t>(type_convert<float>(x)); return f8_convert_rne<bf8_fnuz_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils:: return utils::cast_to_f8<half_t,
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( bf8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to fp8 (f8_fnuz_t)
// the rounding mode is selected at compile time by CK_USE_SR_F8_CONVERSION
template <>
inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
    // stochastic rounding
    const f8_fnuz_t converted = f8_convert_sr<f8_fnuz_t>(x);
#else
    // rounding to nearest even
    const f8_fnuz_t converted = f8_convert_rne<f8_fnuz_t>(x);
#endif
    return converted;
}
// convert fp32 to fp8 // convert fp32 to fp8
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x) inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x); return f8_convert_sr<f8_ocp_t>(x);
#else #else
return f8_convert_rne<f8_t>(x); return f8_convert_rne<f8_ocp_t>(x);
#endif #endif
} }
// convert fp8 to fp32 // convert fp8 to fp32
template <> template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x) inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
float fval; float fval;
...@@ -392,30 +427,44 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x) ...@@ -392,30 +427,44 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
return fval; return fval;
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x); return utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(x);
#endif #endif
} }
template <> template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x) inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
const auto i16val = bit_cast<uint16_t>(x); const auto i16val = bit_cast<uint16_t>(x);
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0); return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x); const auto f8x2_v = vector_type<f8_fnuz_t, 2>(x);
vector_type<float, 2> f32x2_v; vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) = f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>( utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<0>{}]); f8x2_v.template AsType<f8_fnuz_t>()[Number<0>{}]);
f32x2_v.template AsType<float>()(Number<1>{}) = f32x2_v.template AsType<float>()(Number<1>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>( utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<1>{}]); f8x2_v.template AsType<f8_fnuz_t>()[Number<1>{}]);
return f32x2_v.template AsType<float2_t>()[Number<0>{}]; return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif #endif
} }
// convert a two-element OCP fp8 vector to a two-element fp32 vector
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_t x)
{
#if CK_OCP_FP8_CVT_FAST_PATH
    // fast path: hand the packed pair to a single conversion call
    const auto packed_pair = x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}];
    return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(packed_pair);
#else
    // generic path: convert element 0, then element 1, and pack the results
    const auto elem0 = fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
        x.AsType<fp8_storage_t>()[Number<0>{}]);
    const auto elem1 = fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
        x.AsType<fp8_storage_t>()[Number<1>{}]);
    return float2_t{elem0, elem1};
#endif
}
template <> template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x) inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{ {
...@@ -428,42 +477,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x) ...@@ -428,42 +477,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
// convert fp16 to fp8 // convert fp16 to fp8
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x) inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x); return f8_convert_sr<f8_fnuz_t>(x);
#else #else
return f8_convert_rne<f8_t>(x); return f8_convert_rne<f8_fnuz_t>(x);
#endif
}
// convert fp16 to fp8 (f8_ocp_t)
// the rounding mode is selected at compile time by CK_USE_SR_F8_CONVERSION
template <>
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
    // stochastic rounding
    const f8_ocp_t converted = f8_convert_sr<f8_ocp_t>(x);
#else
    // rounding to nearest even
    const f8_ocp_t converted = f8_convert_rne<f8_ocp_t>(x);
#endif
    return converted;
}
// convert fp8 to fp16 // convert fp8 to fp16
template <> template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x) inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// use native conversion to float and convert to fp16 // use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x)); return type_convert<half_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x); return utils::cast_from_f8<f8_fnuz_t, half_t, negative_zero_nan>(x);
#endif
}
// convert fp32 to bf8 (bf8_fnuz_t)
// the rounding mode is selected at compile time by CK_USE_SR_F8_CONVERSION
template <>
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
    // stochastic rounding
    const bf8_fnuz_t converted = f8_convert_sr<bf8_fnuz_t>(x);
#else
    // rounding to nearest even
    const bf8_fnuz_t converted = f8_convert_rne<bf8_fnuz_t>(x);
#endif
    return converted;
}
// convert fp32 to bf8 // convert fp32 to bf8
template <> template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x) inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x); return f8_convert_sr<bf8_ocp_t>(x);
#else #else
return f8_convert_rne<bf8_t>(x); return f8_convert_rne<bf8_ocp_t>(x);
#endif #endif
} }
// convert bf8 to fp32 // convert bf8 to fp32
template <> template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x) inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
float fval; float fval;
...@@ -473,31 +544,42 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x) ...@@ -473,31 +544,42 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
return fval; return fval;
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x); return utils::cast_from_f8<bf8_fnuz_t, float, negative_zero_nan>(x);
#endif
}
// convert fp16 to bf8 (bf8_fnuz_t)
// the rounding mode is selected at compile time by CK_USE_SR_F8_CONVERSION
template <>
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
    // stochastic rounding
    const bf8_fnuz_t converted = f8_convert_sr<bf8_fnuz_t>(x);
#else
    // rounding to nearest even
    const bf8_fnuz_t converted = f8_convert_rne<bf8_fnuz_t>(x);
#endif
    return converted;
}
// convert fp16 to bf8 // convert fp16 to bf8
template <> template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x) inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x); return f8_convert_sr<bf8_ocp_t>(x);
#else #else
return f8_convert_rne<bf8_t>(x); return f8_convert_rne<bf8_ocp_t>(x);
#endif #endif
} }
// convert bf8 to fp16 // convert bf8 to fp16
template <> template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x) inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// use native conversion to float and convert to fp16 // use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x)); return type_convert<half_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x); return utils::cast_from_f8<bf8_fnuz_t, half_t, negative_zero_nan>(x);
#endif #endif
} }
......
# ck_tile [Back to the main page](../../README.md)
# Composable Kernel Tile
## concept ## concept
`ck_tile` provides a programming model with templated abstractions to enable users to implement performance-critical kernels for machine learning workloads. It introduces the following basic concepts to help users build their own operators `ck_tile` provides a programming model with templated abstractions to enable users to implement performance-critical kernels for machine learning workloads. It introduces the following basic concepts to help users build their own operators
- tensor coordinate transformation: this is the core concept of the layout/index transform abstraction at both compile time and run time. - tensor coordinate transformation: this is the core concept of the layout/index transform abstraction at both compile time and run time.
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
namespace ck_tile {
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
// Reinterpret a pointer qualified with CK_CONSTANT_ADDRESS_SPACE (address space 4)
// as a plain T* and return it. Device-only.
template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
{
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
// only a C-style pointer cast seems to compile here, so -Wold-style-cast is
// silenced for just this statement and restored right after it
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
// Reinterpret a plain T* as a pointer qualified with CK_CONSTANT_ADDRESS_SPACE
// (address space 4) and return it. Callable from host and device code.
template <typename T>
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
{
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
// only a C-style pointer cast seems to compile here, so -Wold-style-cast is
// silenced for just this statement and restored right after it
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
} // namespace ck_tile
...@@ -183,4 +183,116 @@ void reference_gemm_gpu(DeviceMem& a_device, ...@@ -183,4 +183,116 @@ void reference_gemm_gpu(DeviceMem& a_device,
return; return;
} }
// Compute a batched reference GEMM on the GPU using naive_gemm_kernel.
//
// Scratch buffers for A, B and C are allocated, A and B are filled from
// a_device / b_device, one naive_gemm_kernel launch is issued per batch (one
// thread per element of the M x N output, 256 threads per block), and the
// result is copied back into c_device.
//
// a_device, b_device : inputs, read through GetDeviceBuffer()
// c_device           : output, written through GetDeviceBuffer()
// M, N, K            : GEMM sizes forwarded to the kernel
// stride_a/b/c       : strides forwarded unchanged to the kernel
// batch_stride_A/B/C : element offset between consecutive batches
// batch_count        : number of batches
//
// Errors are reported on std::cerr. If an allocation or an input copy fails,
// the function returns without computing anything; the scratch buffers are
// released on every exit path.
//
// NOTE(review): the scratch buffers hold batch_count * M * K (resp. N * K,
// M * N) elements, while batches are addressed with batch_stride_A/B/C. This
// presumably assumes densely packed batches; confirm against the callers.
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_batched_gemm_gpu(DeviceMem& a_device,
                                DeviceMem& b_device,
                                DeviceMem& c_device,
                                index_t M,
                                index_t N,
                                index_t K,
                                index_t stride_a,
                                index_t stride_b,
                                index_t stride_c,
                                index_t batch_stride_A,
                                index_t batch_stride_B,
                                index_t batch_stride_C,
                                index_t batch_count)
{
    // sizeof() goes first so that every product is evaluated in the unsigned
    // type of sizeof() instead of being computed (and possibly overflowed) in index_t
    const auto a_bytes = sizeof(ADataType) * batch_count * M * K;
    const auto b_bytes = sizeof(BDataType) * batch_count * N * K;
    const auto c_bytes = sizeof(CDataType) * batch_count * M * N;

    // start from nullptr so the cleanup below is valid on every path
    ADataType* d_A = nullptr;
    BDataType* d_B = nullptr;
    CDataType* d_C = nullptr;

    // releases all three scratch buffers and reports any failure
    const auto release_scratch = [&]() {
        const hipError_t freeA = hipFree(d_A);
        if(freeA != hipSuccess)
        {
            std::cerr << "Error free the A memory: " << hipGetErrorString(freeA) << std::endl;
        }

        const hipError_t freeB = hipFree(d_B);
        if(freeB != hipSuccess)
        {
            std::cerr << "Error free the B memory: " << hipGetErrorString(freeB) << std::endl;
        }

        const hipError_t freeC = hipFree(d_C);
        if(freeC != hipSuccess)
        {
            std::cerr << "Error free the C memory: " << hipGetErrorString(freeC) << std::endl;
        }
    };

    hipError_t errA = hipMalloc(&d_A, a_bytes);
    hipError_t errB = hipMalloc(&d_B, b_bytes);
    hipError_t errC = hipMalloc(&d_C, c_bytes);

    if(errA != hipSuccess)
    {
        std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
                  << std::endl;
    }
    if(errB != hipSuccess)
    {
        std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
                  << std::endl;
    }
    if(errC != hipSuccess)
    {
        std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
                  << std::endl;
    }
    if(errA != hipSuccess || errB != hipSuccess || errC != hipSuccess)
    {
        // free whichever allocations did succeed before the early exit
        release_scratch();
        return;
    }

    // NOTE(review): the sources come from GetDeviceBuffer(), yet the copy kind is
    // hipMemcpyHostToDevice. Confirm the intended direction (hipMemcpyDeviceToDevice
    // or hipMemcpyDefault may be what is meant). Kept as in the original.
    errA = hipMemcpy(d_A, a_device.GetDeviceBuffer(), a_bytes, hipMemcpyHostToDevice);
    if(errA != hipSuccess)
    {
        std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
    }

    errB = hipMemcpy(d_B, b_device.GetDeviceBuffer(), b_bytes, hipMemcpyHostToDevice);
    if(errB != hipSuccess)
    {
        std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
    }

    if(errA != hipSuccess || errB != hipSuccess)
    {
        // the inputs are not in place, so a reference result would be meaningless
        release_scratch();
        return;
    }

    int totalElements      = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
    {
        ADataType* d_ATemp = d_A + batch_id * batch_stride_A;
        BDataType* d_BTemp = d_B + batch_id * batch_stride_B;
        CDataType* d_CTemp = d_C + batch_id * batch_stride_C;
        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }

    // kernel launches do not return a status; surface launch failures here.
    // This is only reported: the blocking copy below still runs as before.
    const hipError_t errLaunch = hipGetLastError();
    if(errLaunch != hipSuccess)
    {
        std::cerr << "Error launching naive_gemm_kernel: " << hipGetErrorString(errLaunch)
                  << std::endl;
    }

    // NOTE(review): same question as above for hipMemcpyDeviceToHost with a
    // destination obtained from GetDeviceBuffer().
    errC = hipMemcpy(c_device.GetDeviceBuffer(), d_C, c_bytes, hipMemcpyDeviceToHost);
    if(errC != hipSuccess)
    {
        std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
    }

    release_scratch();
}
} // namespace ck_tile } // namespace ck_tile
...@@ -998,14 +998,14 @@ struct FmhaFwdKernel ...@@ -998,14 +998,14 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
q_dram_naive, q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}), make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{}); sequence<false, kPadHeadDimQ>{});
} }
else else
{ {
return pad_tensor_view( return pad_tensor_view(
q_dram_naive, q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}), make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{}); sequence<false, kPadHeadDimQ>{});
} }
}(); }();
const auto k_dram = [&]() { const auto k_dram = [&]() {
...@@ -1019,7 +1019,7 @@ struct FmhaFwdKernel ...@@ -1019,7 +1019,7 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
k_dram_naive, k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{}); sequence<false, kPadHeadDimQ>{});
}(); }();
const auto v_dram = [&]() { const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
...@@ -1041,7 +1041,7 @@ struct FmhaFwdKernel ...@@ -1041,7 +1041,7 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
v_dram_transposed, v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{}); sequence<kPadHeadDimV, false>{});
} }
else else
{ {
...@@ -1055,7 +1055,7 @@ struct FmhaFwdKernel ...@@ -1055,7 +1055,7 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
v_dram_naive, v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{}); sequence<false, kPadSeqLenK>{});
} }
}(); }();
...@@ -1097,9 +1097,8 @@ struct FmhaFwdKernel ...@@ -1097,9 +1097,8 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentBias>{}, number<FmhaPipeline::kAlignmentBias>{},
number<1>{}); number<1>{});
return pad_tensor_view(bias_dram_naive, return pad_tensor_view(
bias_dram_window_lengths, bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
sequence<kPadSeqLenQ, kPadSeqLenK>{});
}(); }();
return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
......
...@@ -339,7 +339,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -339,7 +339,7 @@ struct FmhaFwdSplitKVCombineKernel
number<FmhaPipeline::kAlignmentOacc>{}, number<FmhaPipeline::kAlignmentOacc>{},
number<1>{}); number<1>{});
auto o_acc_dram_view = pad_tensor_view( const auto o_acc_dram_view = pad_tensor_view(
o_acc_dram_naive, o_acc_dram_naive,
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}), make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{}); sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
......
...@@ -623,14 +623,14 @@ struct FmhaFwdSplitKVKernel ...@@ -623,14 +623,14 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view( return pad_tensor_view(
q_dram_naive, q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}), make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{}); sequence<false, kPadHeadDimQ>{});
} }
else else
{ {
return pad_tensor_view( return pad_tensor_view(
q_dram_naive, q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}), make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{}); sequence<false, kPadHeadDimQ>{});
} }
}(); }();
...@@ -645,7 +645,7 @@ struct FmhaFwdSplitKVKernel ...@@ -645,7 +645,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view( return pad_tensor_view(
k_dram_naive, k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{}); sequence<false, kPadHeadDimQ>{});
}; };
const auto k_dram = [&]() { const auto k_dram = [&]() {
if constexpr(kIsPagedKV) if constexpr(kIsPagedKV)
...@@ -678,7 +678,7 @@ struct FmhaFwdSplitKVKernel ...@@ -678,7 +678,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view( return pad_tensor_view(
v_dram_transposed, v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{}); sequence<kPadHeadDimV, false>{});
} }
else else
{ {
...@@ -692,7 +692,7 @@ struct FmhaFwdSplitKVKernel ...@@ -692,7 +692,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view( return pad_tensor_view(
v_dram_naive, v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{}); sequence<false, kPadSeqLenK>{});
} }
}; };
const auto v_dram = [&]() { const auto v_dram = [&]() {
...@@ -804,9 +804,8 @@ struct FmhaFwdSplitKVKernel ...@@ -804,9 +804,8 @@ struct FmhaFwdSplitKVKernel
number<FmhaPipeline::kAlignmentBias>{}, number<FmhaPipeline::kAlignmentBias>{},
number<1>{}); number<1>{});
return pad_tensor_view(bias_dram_naive, return pad_tensor_view(
bias_dram_window_lengths, bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
sequence<kPadSeqLenQ, kPadSeqLenK>{});
}(); }();
return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
......
...@@ -25,6 +25,10 @@ ...@@ -25,6 +25,10 @@
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment