Commit bbb2cb69 authored by Jing Zhang
Browse files

fixed comments

parent bc868da4
......@@ -19,8 +19,9 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteB = true;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true;
static constexpr ck::index_t KPerBlock = 128;
// clang-format off
......@@ -29,18 +30,6 @@ using DeviceGemmV2Instance =
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
#if 0
128,
16, 128,
KPerBlock, 8, 32,
16, 16,
1, 4,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
#elif 1
128,
16, 64,
KPerBlock, 8, 32,
......@@ -51,19 +40,7 @@ using DeviceGemmV2Instance =
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
#else
128,
16, 32,
KPerBlock, 8, 32,
16, 16,
1, 1,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
#endif
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, false, PermuteB>;
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>;
// clang-format on
......@@ -156,7 +133,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2);
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute
......@@ -241,55 +218,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
#if 0
std::cout << "a_m_k: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < K; j++)
{
std::cout << ck::type_convert<float>(a_m_k(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "b_k_n: " << std::endl;
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
ck::pk_i4_t i4x2 = b_k_n(j, i);
int8_t i4 = 0;
if( j % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
std::cout << ck::type_convert<float>(i4) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_device_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_host_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_host_result(i, j)) << ",";
}
std::cout << std::endl;
}
#endif
}
if(config.time_kernel)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
......@@ -27,7 +27,6 @@ using DeviceGemmV2Instance =
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
#if 1
64,
16, 16,
256, 8, 16,
......@@ -39,19 +38,6 @@ using DeviceGemmV2Instance =
2, 16, 16, 0,
1, 1, S<1, 16, 1, 4>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
#else
128,
16, 32,
128, 8, 16,
16, 16,
1, 1,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
#endif
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
......
......@@ -21,6 +21,7 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true;
static constexpr ck::index_t KPerBlock = 128;
......@@ -31,7 +32,6 @@ using DeviceGemmV2Instance =
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
#if 0
128,
16, 128,
KPerBlock, 8, 32,
......@@ -42,19 +42,7 @@ using DeviceGemmV2Instance =
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
#else
128,
16, 64,
KPerBlock, 8, 32,
16, 16,
1, 2,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
#endif
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, false, PermuteB>;
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>;
// clang-format on
......@@ -283,55 +271,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
#if 0
std::cout << "a_m_k: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < K; j++)
{
std::cout << ck::type_convert<float>(a_m_k(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "b_k_n: " << std::endl;
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
ck::pk_i4_t i4x2 = b_k_n(j, i);
int8_t i4 = 0;
if( j % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
std::cout << ck::type_convert<float>(i4) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_device_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_host_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_host_result(i, j)) << ",";
}
std::cout << std::endl;
}
#endif
}
if(config.time_kernel)
......
......@@ -325,9 +325,13 @@ struct Tensor
// Returns the element-space size used for this tensor's storage, derived from mDesc.
// When T (with cv/ref qualifiers removed) is ck::pk_i4_t, the descriptor's size is
// halved -- presumably because two int4 values share one pk_i4_t element (TODO confirm);
// for every other T the descriptor's size is returned unchanged.
// NOTE(review): two return statements appear in the pk_i4_t branch. This looks like a
// diff rendering in which the unbraced `/ 2` line is the superseded version and the
// braced `(... + 1) / 2` line is its replacement -- confirm against the real file.
std::size_t GetElementSpaceSize() const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
// Truncating halving: an odd descriptor size drops the trailing half element.
return mDesc.GetElementSpaceSize() / 2;
{
// Rounding-up halving: an odd descriptor size still gets a whole storage element.
return (mDesc.GetElementSpaceSize() + 1) / 2;
}
else
{
// Non-packed types: one storage element per descriptor element.
return mDesc.GetElementSpaceSize();
}
}
// Storage footprint in bytes: the element-space size multiplied by sizeof(T).
std::size_t GetElementSpaceSizeInBytes() const
{
    const std::size_t element_count = GetElementSpaceSize();
    return element_count * sizeof(T);
}
......
......@@ -165,8 +165,7 @@ struct StaticTensorTupleOfVectorBuffer
// Get X
// Idx is for S, not X. Idx should be aligned with X
template <typename X,
typename Idx>
template <typename X, typename Idx>
__host__ __device__ constexpr X GetAsType(Idx) const
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
......@@ -196,8 +195,7 @@ struct StaticTensorTupleOfVectorBuffer
// Set X
// Idx is for S, not X. Idx should be aligned with X
template <typename X,
typename Idx>
template <typename X, typename Idx>
__host__ __device__ constexpr void SetAsType(Idx, X x)
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
......
......@@ -11,9 +11,8 @@ namespace ck {
using bhalf_t = ushort;
using half_t = _Float16;
using int4_t = _BitInt(4);
using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8);
//using pk_i4_t = uint8_t;
// custom data type - pack int4 data
struct pk_i4_t
{
using type = int8_t;
......@@ -1894,18 +1893,11 @@ using bf8x32_t = bf8x32_fnuz_t;
using bf8x64_t = bf8x64_fnuz_t;
#endif
// pack int4
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type;
// u8
// using uint8x2_t = typename vector_type<uint8_t, 2>::type;
// using uint8x4_t = typename vector_type<uint8_t, 4>::type;
// using uint8x8_t = typename vector_type<uint8_t, 8>::type;
// using uint8x16_t = typename vector_type<uint8_t, 16>::type;
// using uint8x32_t = typename vector_type<uint8_t, 32>::type;
// using uint8x64_t = typename vector_type<uint8_t, 64>::type;
template <typename T>
struct NumericLimits
{
......
......@@ -114,8 +114,7 @@ struct StaticBufferTupleOfVector
// Get X
// i is offset of S, not X. i should be aligned to X
template <typename X,
index_t I>
template <typename X, index_t I>
__host__ __device__ constexpr auto GetAsType(Number<I> i) const
{
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
......@@ -131,8 +130,7 @@ struct StaticBufferTupleOfVector
// Set X
// i is offset of S, not X. i should be aligned to X
template <typename X,
index_t I>
template <typename X, index_t I>
__host__ __device__ constexpr void SetAsType(Number<I> i, X x)
{
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment