Unverified Commit 17ed368f authored by carlushuang, committed by GitHub

[CK_TILE][FA] using pk f16_f32 (#1343)

* [CK_TILE][FA] using pk f16_f32

* correct an error
parent e0210316
@@ -61,10 +61,13 @@ CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
 CK_TILE_DEVICE void block_sync_lds()
 {
 #if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
-    asm volatile("\
-    s_waitcnt lgkmcnt(0) \n \
-    s_barrier \
-    " ::);
+    // asm volatile("\
+    // s_waitcnt lgkmcnt(0) \n \
+    // s_barrier \
+    // " ::);
+    __builtin_amdgcn_s_waitcnt(0xc07f);
+    __builtin_amdgcn_s_barrier();
 #else
     __syncthreads();
 #endif
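The 0xc07f immediate passed to __builtin_amdgcn_s_waitcnt packs the s_waitcnt counters into one SIMM16. The sketch below is not part of the commit: the helper name is invented and the field layout assumed is the gfx9 encoding (vmcnt split across bits [3:0] and [15:14], expcnt in bits [6:4], lgkmcnt in bits [11:8]). It shows why the builtin call is equivalent to the removed inline asm s_waitcnt lgkmcnt(0): lgkmcnt is forced to zero while vmcnt and expcnt stay at their maxima, so there is no wait on vector memory or export.

// Sketch only (assumed gfx9 layout): build an s_waitcnt immediate from its three counters.
constexpr unsigned make_waitcnt_gfx9(unsigned vmcnt, unsigned expcnt, unsigned lgkmcnt)
{
    return ((vmcnt & 0xFu) << 0) | ((expcnt & 0x7u) << 4) | ((lgkmcnt & 0xFu) << 8) |
           (((vmcnt >> 4) & 0x3u) << 14);
}
// Waiting only for LDS/GDS traffic (lgkmcnt == 0) reproduces the magic constant.
static_assert(make_waitcnt_gfx9(63, 7, 0) == 0xc07f, "0xc07f == wait for lgkmcnt(0) only");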
@@ -167,6 +167,10 @@
 #define CK_TILE_USE_SUBDWORD_TILE_CAST 0
 #endif
+#ifndef CK_TILE_USE_PK_FP16_TILE_CAST
+#define CK_TILE_USE_PK_FP16_TILE_CAST 0
+#endif
 // TODO: better solve this inside compiler
 #ifndef CK_TILE_FMHA_FWD_FAST_EXP2
 #define CK_TILE_FMHA_FWD_FAST_EXP2 0
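CK_TILE_USE_PK_FP16_TILE_CAST defaults to 0, so the new path is opt-in. A hypothetical way to enable it for a build (the -D flag is ordinary compiler usage; the include path is only illustrative):

// Illustrative only: opt in to the packed fp32 -> fp16 tile cast.
// Either pass -DCK_TILE_USE_PK_FP16_TILE_CAST=1 to hipcc/clang, or define the macro
// before the ck_tile configuration header is included (its #ifndef guard then yields).
#define CK_TILE_USE_PK_FP16_TILE_CAST 1
#include "ck_tile/core.hpp" // assumed include path, for illustration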
@@ -110,7 +110,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
 namespace impl {
 // TODO: this is ugly
 template <typename OutDataType, typename InTensor>
-CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
+CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
 {
 #if defined(__gfx94__)
     // This API is designed to use the _pk_ series of functions
@@ -156,6 +156,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
 #endif
 }
+template <typename OutDataType, typename InTensor>
+CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
+{
+#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
+    // This API is designed to use the _pk_ series of functions
+    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
+    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
+    static_assert(thread_buffer_size % 2 == 0);
+    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
+    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
+    // TODO: this is an RTZ conversion, so it needs to be used very carefully
+    for(index_t i = 0; i < thread_buffer_size_pk; i++)
+    {
+        auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
+                                            in_dstr_tensors.get_thread_buffer()[2 * i + 1]);
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
+    }
+    return out_dstr_tensor;
+#else
+    // fallback
+    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
+                               in_dstr_tensors);
+#endif
+}
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
 // this function assumes either the src or dst (or both) data type is under 1 dword
 // we pack subdword values into 1 dword to avoid the compiler's default subdword behavior (which is buggy)
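The heart of the new function is __builtin_amdgcn_cvt_pkrtz, which converts two fp32 values into two fp16 values with a single v_cvt_pkrtz_f16_f32 instruction, rounding toward zero rather than to nearest-even; that rounding mode is what the TODO warns about. A standalone HIP sketch of the same two-at-a-time loop (not from the repository; the kernel and type names are made up):

#include <hip/hip_runtime.h>

// Sketch only: a two-element vector of _Float16, matching what the builtin returns.
using half2_t = _Float16 __attribute__((ext_vector_type(2)));

// Convert 2 * n_pairs floats to fp16 with round-toward-zero, two values per instruction.
__global__ void pack_fp16_rtz(const float* in, _Float16* out, int n_pairs)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n_pairs)
    {
        const half2_t o = __builtin_amdgcn_cvt_pkrtz(in[2 * i + 0], in[2 * i + 1]);
        out[2 * i + 0] = o.x;
        out[2 * i + 1] = o.y;
    }
}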
@@ -229,8 +260,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
                                      float> &&
                      (SrcTensor::get_thread_buffer_size() % 4 == 0))
     {
-        return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
+        return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
     }
+#if CK_TILE_USE_PK_FP16_TILE_CAST
+    else if constexpr(std::is_same_v<DstType, fp16_t> &&
+                      std::is_same_v<typename SrcTensor::DataType, float> &&
+                      (SrcTensor::get_thread_buffer_size() % 2 == 0))
+    {
+        return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
+    }
+#endif
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
     else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
     {
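Note that the dispatch leaves call sites untouched: callers still write cast_tile<fp16_t>(tile) and the packed branch is selected purely from the destination type, the source data type, and the thread-buffer size. A self-contained toy of the same compile-time dispatch pattern (none of these names are the ck_tile API; host-side _Float16 support is assumed, and plain casts stand in for the packed builtin):

#include <cstddef>
#include <type_traits>

template <typename T, std::size_t N>
struct toy_tile { using DataType = T; static constexpr std::size_t size = N; T buf[N]; };

template <typename DstT, typename SrcTile>
auto toy_cast_tile(const SrcTile& src)
{
    toy_tile<DstT, SrcTile::size> dst{};
    if constexpr(std::is_same_v<DstT, _Float16> &&
                 std::is_same_v<typename SrcTile::DataType, float> && (SrcTile::size % 2 == 0))
    {
        // "packed" branch: two elements per step, as cast_tile_pk_fp16_fp32 does
        for(std::size_t i = 0; i < SrcTile::size / 2; ++i)
        {
            dst.buf[2 * i + 0] = static_cast<_Float16>(src.buf[2 * i + 0]);
            dst.buf[2 * i + 1] = static_cast<_Float16>(src.buf[2 * i + 1]);
        }
    }
    else
    {
        // generic element-wise fallback
        for(std::size_t i = 0; i < SrcTile::size; ++i)
            dst.buf[i] = static_cast<DstT>(src.buf[i]);
    }
    return dst;
}
// Usage: toy_tile<float, 8> acc{}; auto p = toy_cast_tile<_Float16>(acc); // packed branch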
@@ -578,8 +578,14 @@ struct BlockFmhaPipelineQRKSVSAsync
                                          randval_dram_window);
             }
-            const auto p =
-                cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
+            const auto p = [&]() {
+                if constexpr(std::is_same_v<PDataType, fp16_t>)
+                    return impl::cast_tile_pk_fp16_fp32<PDataType>(
+                        tile_elementwise_in(p_compute_element_func, p_compute));
+                else
+                    return cast_tile<PDataType>(
+                        tile_elementwise_in(p_compute_element_func, p_compute));
+            }();
             // STAGE 3, KV gemm
             if constexpr(k1_loops > 1)
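The immediately-invoked lambda is what keeps p const even though the expression that produces it is chosen at compile time from PDataType. A stripped-down sketch of the pattern (a scalar instead of a tile; the names are invented):

#include <type_traits>

template <typename PDataType>
PDataType make_p(float p_compute)
{
    // const initialization selected by if constexpr inside an immediately-invoked lambda
    const auto p = [&]() {
        if constexpr(std::is_same_v<PDataType, float>)
            return p_compute;                         // no conversion needed
        else
            return static_cast<PDataType>(p_compute); // converting path
    }();
    return p;
}

Without the lambda, p would have to be declared non-const and assigned inside separate if constexpr branches.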