Commit 5dee1b64 authored by Ville Pietilä's avatar Ville Pietilä
Browse files

Merge remote-tracking branch 'origin/develop' into vpietila/ggemm-profiling

parents 870c3a76 355893cd
......@@ -4,7 +4,7 @@
#pragma once
namespace ck {
// Define the common macro for gfx94x models
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
......
This diff is collapsed.
......@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00;
};
static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
static inline __host__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x)
......@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00;
};
static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
static inline __device__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
static inline __device__ half_t sqrt(half_t x)
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace ck {
// Pseudo random number generator
......@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// version for fp16
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
......@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// return 0 if data is not fp16 or fp32
template <typename T,
template <
typename T,
uint32_t seed_t,
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{
std::ignore = id;
......
This diff is collapsed.
# ck_tile
[Back to the main page](../../README.md)
# Composable Kernel Tile
## concept
`ck_tile` provides a programming model with templated abstractions to enable users to implement performance-critical kernels for machine learning workloads. introduces following basic concepts to help users building your own operator
- tensor coordinate transformation, this is the core concept of layout/index transform abstraction in both compiler time and run time.
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
namespace ck_tile {
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
{
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
template <typename T>
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
{
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
} // namespace ck_tile
......@@ -339,7 +339,7 @@ struct FmhaFwdSplitKVCombineKernel
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
auto o_acc_dram_view = pad_tensor_view(
const auto o_acc_dram_view = pad_tensor_view(
o_acc_dram_naive,
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
......
......@@ -623,14 +623,14 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
}
}();
......@@ -645,7 +645,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
};
const auto k_dram = [&]() {
if constexpr(kIsPagedKV)
......@@ -678,7 +678,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
sequence<kPadHeadDimV, false>{});
}
else
{
......@@ -692,7 +692,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
sequence<false, kPadSeqLenK>{});
}
};
const auto v_dram = [&]() {
......@@ -804,9 +804,8 @@ struct FmhaFwdSplitKVKernel
number<FmhaPipeline::kAlignmentBias>{},
number<1>{});
return pad_tensor_view(bias_dram_naive,
bias_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
return pad_tensor_view(
bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
}();
return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
......
......@@ -25,6 +25,7 @@
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
......
......@@ -66,6 +66,79 @@ struct GemmKernel
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_HOST static bool IsSupportedArgument(const GemmCommonKargs& kargs)
{
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
{
return false;
}
if(kargs.K % GemmPipeline::VectorSizeA != 0)
{
return false;
}
}
else
{
if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
{
return false;
}
if(kargs.M % GemmPipeline::VectorSizeA != 0)
{
return false;
}
}
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
{
return false;
}
if(kargs.N % GemmPipeline::VectorSizeB != 0)
{
return false;
}
}
else
{
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
{
return false;
}
if(kargs.K % GemmPipeline::VectorSizeB != 0)
{
return false;
}
}
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
{
return false;
}
if(kargs.N % GemmPipeline::VectorSizeC != 0)
{
return false;
}
}
else
{
if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
{
return false;
}
if(kargs.M % GemmPipeline::VectorSizeC != 0)
{
return false;
}
}
return true;
}
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
......
This diff is collapsed.
......@@ -15,7 +15,7 @@ void add_device_pool3d_fwd_ndhwc_f8_instances(
instances)
{
add_device_operation_instances(
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F8, ReduceOpId, false>{});
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F32, ReduceOpId, false>{});
}
void add_device_pool3d_fwd_ndhwc_index_f8_instances(
......@@ -23,7 +23,7 @@ void add_device_pool3d_fwd_ndhwc_index_f8_instances(
instances)
{
add_device_operation_instances(
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F8, ReduceOpId, true>{});
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F32, ReduceOpId, true>{});
}
} // namespace instance
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment