"...composable_kernel.git" did not exist on "fc7e83ee523a7fc9c675cfeb36c6fe2504534fc7"
Commit bbb2cb69 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed comments

parent bc868da4
...@@ -19,8 +19,9 @@ using AElementOp = PassThrough; ...@@ -19,8 +19,9 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteB = true; static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true;
static constexpr ck::index_t KPerBlock = 128; static constexpr ck::index_t KPerBlock = 128;
// clang-format off // clang-format off
...@@ -29,18 +30,6 @@ using DeviceGemmV2Instance = ...@@ -29,18 +30,6 @@ using DeviceGemmV2Instance =
ALayout, BLayout, CLayout, ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, AElementOp, BElementOp, CElementOp, GemmDefault,
#if 0
128,
16, 128,
KPerBlock, 8, 32,
16, 16,
1, 4,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
#elif 1
128, 128,
16, 64, 16, 64,
KPerBlock, 8, 32, KPerBlock, 8, 32,
...@@ -51,19 +40,7 @@ using DeviceGemmV2Instance = ...@@ -51,19 +40,7 @@ using DeviceGemmV2Instance =
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0, 2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4, 1, 1, S<1, 16, 1, 8>, 4,
#else ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>;
128,
16, 32,
KPerBlock, 8, 32,
16, 16,
1, 1,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
#endif
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, false, PermuteB>;
// clang-format on // clang-format on
...@@ -156,7 +133,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -156,7 +133,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute // weight permute
...@@ -241,55 +218,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -241,55 +218,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
"Error: Incorrect results!", "Error: Incorrect results!",
get_rtol<CDataType>(), get_rtol<CDataType>(),
get_atol<CDataType>()); get_atol<CDataType>());
#if 0
std::cout << "a_m_k: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < K; j++)
{
std::cout << ck::type_convert<float>(a_m_k(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "b_k_n: " << std::endl;
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
ck::pk_i4_t i4x2 = b_k_n(j, i);
int8_t i4 = 0;
if( j % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
std::cout << ck::type_convert<float>(i4) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_device_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_host_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_host_result(i, j)) << ",";
}
std::cout << std::endl;
}
#endif
} }
if(config.time_kernel) if(config.time_kernel)
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
...@@ -27,7 +27,6 @@ using DeviceGemmV2Instance = ...@@ -27,7 +27,6 @@ using DeviceGemmV2Instance =
ALayout, BLayout, CLayout, ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, AElementOp, BElementOp, CElementOp, GemmDefault,
#if 1
64, 64,
16, 16, 16, 16,
256, 8, 16, 256, 8, 16,
...@@ -39,19 +38,6 @@ using DeviceGemmV2Instance = ...@@ -39,19 +38,6 @@ using DeviceGemmV2Instance =
2, 16, 16, 0, 2, 16, 16, 0,
1, 1, S<1, 16, 1, 4>, 4, 1, 1, S<1, 16, 1, 4>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
#else
128,
16, 32,
128, 8, 16,
16, 16,
1, 1,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
#endif
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
......
...@@ -21,6 +21,7 @@ using CElementOp = PassThrough; ...@@ -21,6 +21,7 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true; static constexpr bool PermuteB = true;
static constexpr ck::index_t KPerBlock = 128; static constexpr ck::index_t KPerBlock = 128;
...@@ -31,7 +32,6 @@ using DeviceGemmV2Instance = ...@@ -31,7 +32,6 @@ using DeviceGemmV2Instance =
ALayout, BLayout, CLayout, ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, AElementOp, BElementOp, CElementOp, GemmDefault,
#if 0
128, 128,
16, 128, 16, 128,
KPerBlock, 8, 32, KPerBlock, 8, 32,
...@@ -42,19 +42,7 @@ using DeviceGemmV2Instance = ...@@ -42,19 +42,7 @@ using DeviceGemmV2Instance =
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0, 2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4, 1, 1, S<1, 16, 1, 8>, 4,
#else ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>;
128,
16, 64,
KPerBlock, 8, 32,
16, 16,
1, 2,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
#endif
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, false, PermuteB>;
// clang-format on // clang-format on
...@@ -283,55 +271,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -283,55 +271,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
"Error: Incorrect results!", "Error: Incorrect results!",
get_rtol<CDataType>(), get_rtol<CDataType>(),
get_atol<CDataType>()); get_atol<CDataType>());
#if 0
std::cout << "a_m_k: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < K; j++)
{
std::cout << ck::type_convert<float>(a_m_k(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "b_k_n: " << std::endl;
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
ck::pk_i4_t i4x2 = b_k_n(j, i);
int8_t i4 = 0;
if( j % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
std::cout << ck::type_convert<float>(i4) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_device_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_host_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_host_result(i, j)) << ",";
}
std::cout << std::endl;
}
#endif
} }
if(config.time_kernel) if(config.time_kernel)
......
...@@ -325,9 +325,13 @@ struct Tensor ...@@ -325,9 +325,13 @@ struct Tensor
std::size_t GetElementSpaceSize() const std::size_t GetElementSpaceSize() const
{ {
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>) if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
return mDesc.GetElementSpaceSize() / 2; {
return (mDesc.GetElementSpaceSize() + 1) / 2;
}
else else
{
return mDesc.GetElementSpaceSize(); return mDesc.GetElementSpaceSize();
}
} }
std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); } std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
......
...@@ -165,8 +165,7 @@ struct StaticTensorTupleOfVectorBuffer ...@@ -165,8 +165,7 @@ struct StaticTensorTupleOfVectorBuffer
// Get X // Get X
// Idx is for S, not X. Idx should be aligned with X // Idx is for S, not X. Idx should be aligned with X
template <typename X, template <typename X, typename Idx>
typename Idx>
__host__ __device__ constexpr X GetAsType(Idx) const __host__ __device__ constexpr X GetAsType(Idx) const
{ {
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
...@@ -196,8 +195,7 @@ struct StaticTensorTupleOfVectorBuffer ...@@ -196,8 +195,7 @@ struct StaticTensorTupleOfVectorBuffer
// Set X // Set X
// Idx is for S, not X. Idx should be aligned with X // Idx is for S, not X. Idx should be aligned with X
template <typename X, template <typename X, typename Idx>
typename Idx>
__host__ __device__ constexpr void SetAsType(Idx, X x) __host__ __device__ constexpr void SetAsType(Idx, X x)
{ {
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
......
...@@ -11,9 +11,8 @@ namespace ck { ...@@ -11,9 +11,8 @@ namespace ck {
using bhalf_t = ushort; using bhalf_t = ushort;
using half_t = _Float16; using half_t = _Float16;
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8); // custom data type - pack int4 data
//using pk_i4_t = uint8_t;
struct pk_i4_t struct pk_i4_t
{ {
using type = int8_t; using type = int8_t;
...@@ -1894,18 +1893,11 @@ using bf8x32_t = bf8x32_fnuz_t; ...@@ -1894,18 +1893,11 @@ using bf8x32_t = bf8x32_fnuz_t;
using bf8x64_t = bf8x64_fnuz_t; using bf8x64_t = bf8x64_fnuz_t;
#endif #endif
// pack int4
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type; using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type; using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type; using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type;
// u8
// using uint8x2_t = typename vector_type<uint8_t, 2>::type;
// using uint8x4_t = typename vector_type<uint8_t, 4>::type;
// using uint8x8_t = typename vector_type<uint8_t, 8>::type;
// using uint8x16_t = typename vector_type<uint8_t, 16>::type;
// using uint8x32_t = typename vector_type<uint8_t, 32>::type;
// using uint8x64_t = typename vector_type<uint8_t, 64>::type;
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
{ {
......
...@@ -114,8 +114,7 @@ struct StaticBufferTupleOfVector ...@@ -114,8 +114,7 @@ struct StaticBufferTupleOfVector
// Get X // Get X
// i is offset of S, not X. i should be aligned to X // i is offset of S, not X. i should be aligned to X
template <typename X, template <typename X, index_t I>
index_t I>
__host__ __device__ constexpr auto GetAsType(Number<I> i) const __host__ __device__ constexpr auto GetAsType(Number<I> i) const
{ {
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{}; constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
...@@ -131,8 +130,7 @@ struct StaticBufferTupleOfVector ...@@ -131,8 +130,7 @@ struct StaticBufferTupleOfVector
// Set X // Set X
// i is offset of S, not X. i should be aligned to X // i is offset of S, not X. i should be aligned to X
template <typename X, template <typename X, index_t I>
index_t I>
__host__ __device__ constexpr void SetAsType(Number<I> i, X x) __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
{ {
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{}; constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment