Commit b5083bfe authored by aska-0096

Fp16AInt8B_GEMM sanity

parent 5cf73a5e
......@@ -14,6 +14,8 @@
// The DeviceOp is CDataType = ADataType * Dequant(BDataType) * ScaleDataType
// The HostRef is CDataType = ADataType * Dequant(QuantDataType) * ScaleDataType
// TODO: Current implementation consumes more VGPR than expected.
using ADataType = ck::half_t;
using QuantDataType = int8_t;
using BDataType = uint8_t;
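For orientation, a minimal host-side sketch of the reference computation the comments above describe (illustrative names and plain std types, not CK's Tensor/half_t API; it assumes a per-column scale, matching the 1 x N scale tensor printed later in this example):

// Illustration only: naive host reference for
//   C[m][n] = sum_k A[m][k] * (Scale[n] * float(QuantB[k][n]))
#include <cstdint>
#include <vector>

void reference_fp16_aint8b_gemm(const std::vector<float>& a_m_k,   // M x K activations (fp16 widened to float)
                                const std::vector<int8_t>& qb_k_n, // K x N quantized weights
                                const std::vector<float>& scale_n, // per-column dequantization scales
                                std::vector<float>& c_m_n,         // M x N output
                                int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(int k = 0; k < K; ++k)
                acc += a_m_k[m * K + k] * (scale_n[n] * static_cast<float>(qb_k_n[k * N + n]));
            c_m_n[m * N + n] = acc;
        }
}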
......@@ -49,13 +51,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_
GemmDefault,
1, // Prefetch stage
128, // BlockSize
128, // MPerBlock
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
4, // M-Repeat // MPerBlock / (MPerWmma * MRepeat) = MWave
2, // M-Repeat // MPerBlock / (MPerWmma * MRepeat) = MWave
4, // N-Repeat // NPerBlock / (NPerWmma * NRepeat) = NWave
S<4, 32, 1>,
S<1, 0, 2>,
......
......@@ -67,68 +67,6 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
UnsignedWeightPreprocessor<QuantDataType> preprocessor;
Tensor<BDataType> b_k_n = preprocessor(quant_b_k_n);
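A hedged sketch of what this preprocessing step amounts to (the real UnsignedWeightPreprocessor lives in CK; this only assumes it re-biases each signed int8 weight by +128 into uint8, which is what the -128 recovered by the fp16 magic-number subtraction further down would undo):

// Assumption: re-bias int8 weights into uint8 by adding 128 (equivalently XOR 0x80),
// so the device-side fast converter can subtract the bias back in fp16.
#include <cstdint>
#include <vector>

std::vector<uint8_t> bias_int8_to_uint8(const std::vector<int8_t>& quant_b)
{
    std::vector<uint8_t> out(quant_b.size());
    for(std::size_t i = 0; i < quant_b.size(); ++i)
        out[i] = static_cast<uint8_t>(static_cast<int>(quant_b[i]) + 128);
    return out;
}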
#if 0
printf("Matrix A:\n");
for (int im = 0; im < M; im++)
{
for (int ik = 0; ik < K; ik++)
{
if(ik % 16 == 0){
printf("|");
}
printf(" %04x", *(reinterpret_cast<uint16_t*>(&a_m_k(im,ik))));
}
printf("\n");
}
#endif
#if 0
printf("Matrix QuantB:\n");
for (int in = 0; in < N; in++)
{
for (int ik = 0; ik < K; ik++)
{
if(ik % 16 == 0){
printf("|");
}
printf(" %02x", *(reinterpret_cast<uint8_t*>(&quant_b_k_n(ik,in))));
}
printf("\n");
}
#endif
#if 0
printf("Matrix Scale:\n");
for(int in = 0; in < N; in++)
{
for(int ik = 0; ik < 1; ik++)
{
if(ik % 16 == 0)
{
printf("|");
}
printf(" %04x", *(reinterpret_cast<uint16_t*>(&scale_k_n(ik, in))));
}
printf("\n");
}
#endif
#if 0
printf("Matrix B:\n");
for (int in = 0; in < N; in++)
{
for (int ik = 0; ik < K; ik++)
{
if(ik % 16 == 0){
printf("|");
}
printf(" %02x", b_k_n(ik,in));
}
printf("\n");
}
#endif
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
......
......@@ -408,12 +408,7 @@ struct Blockwise_fpAintB_GemmWMMA
scale_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
scale_thread_buf);
#if 0
printf("Tid: %03d, n: %02d, scale_thread_buf: %04x\n",
get_thread_local_1d_id(), n0.value,
*(reinterpret_cast<const uint16_t*>(&scale_thread_buf[n0]))
);
#endif
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
......@@ -448,72 +443,7 @@ struct Blockwise_fpAintB_GemmWMMA
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
if(true)
{
#if 0
printf("Tid: %03d, m, n, k: %02d, %02d, %02d, a_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x|\n",
get_thread_local_1d_id(), m0.value, n0.value, k.value,
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<0>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<1>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<2>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<3>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<4>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<5>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<6>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<7>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<8>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<9>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<10>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<11>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<12>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<13>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<14>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<15>{}]))
);
#endif
#if 0
printf("Tid: %03d, m, n, k: %02d, %02d, %02d, b_thread_buf: %02x %02x %02x %02x| %02x %02x %02x %02x| %02x %02x %02x %02x| %02x %02x %02x %02x|\n",
get_thread_local_1d_id(), m0.value, n0.value, k.value,
b_thread_buf[Number<0>{}],
b_thread_buf[Number<1>{}],
b_thread_buf[Number<2>{}],
b_thread_buf[Number<3>{}],
b_thread_buf[Number<4>{}],
b_thread_buf[Number<5>{}],
b_thread_buf[Number<6>{}],
b_thread_buf[Number<7>{}],
b_thread_buf[Number<8>{}],
b_thread_buf[Number<9>{}],
b_thread_buf[Number<10>{}],
b_thread_buf[Number<11>{}],
b_thread_buf[Number<12>{}],
b_thread_buf[Number<13>{}],
b_thread_buf[Number<14>{}],
b_thread_buf[Number<15>{}]
);
#endif
#if 0
printf("Tid: %03d, m, n, k: %02d, %02d, %02d, converted_b_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x|\n",
get_thread_local_1d_id(), m0.value, n0.value, k.value,
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<0>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<1>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<2>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<3>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<4>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<5>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<6>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<7>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<8>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<9>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<10>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<11>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<12>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<13>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<14>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<15>{}]))
);
#endif
}
vector_type<ADataType, WmmaK> a_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
......
......@@ -398,23 +398,12 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
uint32_t* half_2 = reinterpret_cast<uint32_t*>(&Output);
uint32_t const uint8_4 = reinterpret_cast<uint32_t const&>(Input);
// printf("Tid: %03d, uint8_4: %08x\n",
// get_thread_local_1d_id(),
// uint8_4);
static constexpr uint32_t byte_selector_01 = 0x05010500;
static constexpr uint32_t byte_selector_23 = 0x05030502;
static constexpr uint32_t fp16_adder = 0x64646464;
half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);
// printf("Tid: %03d, Part1 converted: %08x | %08x\n",
// get_thread_local_1d_id(),
// half_2[Number<0>{}],
// half_2[Number<1>{}]);
// Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed
// integer as fp16.
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]\n"
: "=v"(half_2[0])
......@@ -422,10 +411,7 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]\n"
: "=v"(half_2[1])
: "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
// printf("Tid: %03d, Part2 converted: %08x | %08x\n",
// get_thread_local_1d_id(),
// half_2[Number<0>{}],
// half_2[Number<1>{}]);
return Output;
}
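To see why the constants above work (a scalar model only, not the device code path): v_perm_b32 with byte_selector_01/_23 interleaves 0x64 with each input byte b, producing the fp16 bit pattern 0x6400|b, whose value is 1024 + b; the packed v_pk_add_f16 with the second operand negated then subtracts I8s_TO_F16s_MAGIC_NUM (0x6480 per lane, i.e. 1152.0), leaving b - 128. Both steps are exact in fp16, so the check below can use float arithmetic.

// Scalar model of the packed uint8 -> fp16 conversion trick (illustration only).
#include <cstdint>
#include <cstdio>

float convert_one_byte(uint8_t b)
{
    const float constructed = 1024.0f + static_cast<float>(b); // value encoded by 0x6400|b
    return constructed - 1152.0f;                              // value after subtracting 0x6480
}

int main()
{
    for(int b = 0; b < 256; ++b)
        if(convert_one_byte(static_cast<uint8_t>(b)) != static_cast<float>(b - 128))
            std::printf("mismatch at %d\n", b);
    return 0;
}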
......
......@@ -52,14 +52,6 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
if(false && get_thread_local_1d_id() == 0)
{
printf("lds_size: %lu\n", GridwiseGemm::SharedMemTrait::lds_size);
printf("lds_a_size: %d\n", GridwiseGemm::SharedMemTrait::a_block_space_size_aligned);
printf("lds_b_size: %d\n", GridwiseGemm::SharedMemTrait::b_block_space_size_aligned);
printf("lds_scale_size: %d\n",
GridwiseGemm::SharedMemTrait::scale_block_space_size_aligned);
}
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
......@@ -805,7 +797,7 @@ struct GridwiseFpAintBGemm_Wmma
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
SharedMemTrait::b_block_space_size_aligned);
// printf("b_lds_offset: %lu\n", SharedMemTrait::b_block_space_offset);
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
......@@ -886,7 +878,6 @@ struct GridwiseFpAintBGemm_Wmma
auto scale_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ScaleDataType*>(p_shared) + SharedMemTrait::scale_block_space_offset,
SharedMemTrait::scale_block_space_size_aligned);
// printf("scale_lds_offset: %lu\n", SharedMemTrait::scale_block_space_offset);
auto scale_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
......
......@@ -1143,9 +1143,7 @@ struct ThreadwiseTensorSliceTransfer_v4
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_data_coord);
#if 0
printf("Tid: %03d, LDS read offset: %d\n", get_thread_local_1d_id(), src_data_coord.GetOffset());
#endif
// copy data from src_buf into src_tmp_vector
if constexpr(SrcBuffer::IsDynamicBuffer())
{
......
......@@ -207,13 +207,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
if(false)
{
printf("Tid: %03d, a_grid_buf: %04x\n",
get_thread_local_1d_id(),
*(reinterpret_cast<const uint16_t*>(
&src_vector_container.template AsType<SrcData>()[Number<0>{}])));
}
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<src_vector_t>(
......@@ -448,9 +442,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
#if 0
printf("Tid: %03d, LDS write offset: %d\n", get_thread_local_1d_id(), dst_coord_.GetOffset());
#endif
using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
......
find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
# git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
# find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'