Commit 398f8851 authored by Jing Zhang

debug i4_to_f16_convert

parent 222e9688
......@@ -34,11 +34,11 @@ using DeviceGemmV2Instance =
16, 16,
1, 1,
S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
2, 8, 8, 1,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
2, 32, 32, 1,
1, 1, S<1, 16, 1, 4>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>;
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
#else
128,
16, 32,
......
......@@ -223,6 +223,34 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
get_rtol<CDataType>(),
get_atol<CDataType>());
#if 0
std::cout << "a_m_k: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < K; j++)
{
std::cout << ck::type_convert<float>(a_m_k(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "b_k_n: " << std::endl;
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
ck::pk_i4_t i4x2 = b_k_n(j, i);
int8_t i4 = 0;
if(j % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
std::cout << ck::type_convert<float>(i4) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_device_result: " << std::endl;
for(int i = 0; i < M; i++)
{
......@@ -242,6 +270,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
std::cout << std::endl;
}
#endif
}
if(config.time_kernel)
......
......@@ -30,14 +30,45 @@ __device__ inline half4_t pki4_to_half4(int q)
const int ADD = 0xd480d480; // -72 in f16
vector_type<half_t, 4> res;
res.template AsType<half2_t>()(Number<0>{}) =
amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
return res.template AsType<half4_t>()[Number<0>{}];
}
__device__ inline half2_t pki4_to_half2(pk_i4_t q)
{
#if 0
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;
auto l_f16 = ck::type_convert<ck::half_t>(x_l - 8);
auto h_f16 = ck::type_convert<ck::half_t>(x_h - 8);
return {h_f16, l_f16};
#else
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
int x_l = (x_u8 & 0x0f);
int x_h = (x_u8 & 0xf0) << 12;
const int EX = 0x64006400;
const int SUB = 0xE408E408; // -1032 in f16: removes the 1024 bias and the zero-point 8
int lo = (x_l | x_h) | EX;
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
#endif
}
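The constants above follow the usual fp16 bit trick for decoding packed int4 codes: OR-ing a nibble v into the mantissa of 0x6400 (1024.0 in f16) produces exactly 1024 + v, and adding 0xE408 (-1032.0) recovers the signed value v - 8. The companion constant in pki4_to_half4, 0xD480 (-72.0), plays the same role for the high nibbles once they have been scaled down, presumably by the MUL constant holding 1/16. Below is a minimal host-side sketch of that arithmetic, not the device code itself; it mimics f16 with plain floats, which is exact for these small integers, and assumes the low nibble carries the first element.

```cpp
// Host-side sketch of the 0x6400 / 0xE408 decode trick used in pki4_to_half2.
// Assumes each packed byte holds two unsigned 4-bit codes with zero-point 8.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Interpret an fp16 bit pattern as a float. For patterns of the form
// (0x6400 | nibble) the nibble lands in the low mantissa bits, so the value
// is exactly 1024 + nibble.
static float half_bits_to_float(uint16_t bits)
{
    const int exponent    = ((bits >> 10) & 0x1f) - 15;       // unbiased exponent
    const float mantissa  = 1.0f + (bits & 0x3ff) / 1024.0f;  // implicit leading 1
    const float magnitude = std::ldexp(mantissa, exponent);
    return (bits & 0x8000) ? -magnitude : magnitude;
}

int main()
{
    for(int packed = 0; packed < 256; ++packed)
    {
        const int lo_nibble = packed & 0x0f;
        const int hi_nibble = (packed & 0xf0) >> 4;

        // Same arithmetic as the #else branch of pki4_to_half2:
        // (1024 + v) + (-1032) == v - 8.
        const float lo = half_bits_to_float(0x6400 | lo_nibble) + half_bits_to_float(0xE408);
        const float hi = half_bits_to_float(0x6400 | hi_nibble) + half_bits_to_float(0xE408);

        assert(lo == static_cast<float>(lo_nibble - 8));
        assert(hi == static_cast<float>(hi_nibble - 8));
    }
    std::printf("all 256 packed bytes decode to the expected v - 8 values\n");
    return 0;
}
```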
struct PassThroughPack8
{
template <typename Y, typename X>
......@@ -45,12 +76,24 @@ struct PassThroughPack8
__host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
{
#if 1
vector_type<half_t, 8> result;
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
y = result.template AsType<half8_t>()[Number<0>{}];
#else
vector_type<half_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};
dst.template AsType<half2_t>()(Number<0>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<half2_t>()(Number<1>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<half2_t>()(Number<2>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<half2_t>()(Number<3>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<half8_t>()[Number<0>{}];
#endif
}
constexpr const static bool is_pack8_invocable = true;
......
......@@ -1370,6 +1370,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
......
......@@ -1023,8 +1023,11 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> && (SrcScalarPerVector == 1)),
"pk data N cannot be 1");
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0,
"pk data N cannot be 1");
}
}
template <typename SrcRefToOriginDisplacement,
......@@ -1123,8 +1126,9 @@ struct ThreadwiseTensorSliceTransfer_v4
using src_vector_t = typename decltype(src_tmp_vector)::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_data_coord);
//const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
//src_desc, src_data_coord);
const bool is_src_valid = true;
// copy data from src_buf into src_tmp_vector
if constexpr(SrcBuffer::IsDynamicBuffer())
......@@ -1156,8 +1160,9 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert(SrcScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
using src_v_t = typename vector_type_maker_t<SrcData, 4>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
......
......@@ -77,11 +77,18 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> &&
(SrcScalarPerVector_ == 1 || DstScalarPerVector_ == 1)),
"pk data N cannot be 1");
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
static_assert(SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1");
static_assert(
SrcVectorDim == DstVectorDim,
"pk_i4_t does not support transpose");
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
......@@ -300,6 +307,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
});
#else
#if 0
// OOB Check
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
......@@ -362,10 +370,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
});
static_assert(
!(is_same_v<pk_i4_t, remove_cvref_t<SrcData>> && SrcVectorDim != DstVectorDim),
"pk_i4_t does not support transpose");
#endif
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
......@@ -377,6 +382,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
static_assert(false, "no transpose allowed");
#if 0
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
......@@ -395,9 +402,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
SrcScalarPerVector_,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector_>{},
DstScalarPerVector>{},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
......@@ -434,11 +441,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
#endif
}
else
{
constexpr auto packed_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
......@@ -765,7 +773,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector_>{});
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
......@@ -825,7 +833,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector_>{});
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
......@@ -867,8 +875,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
private:
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto src_oob_thread_scratch_desc_ =
decltype(GetSrcThreadScratchDescriptor()){};
//static constexpr auto src_oob_thread_scratch_desc_ =
//decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
using SrcThreadScratch =
......
......@@ -29,6 +29,13 @@ struct DynamicBuffer
ElementSpaceSize element_space_size_;
T invalid_element_value_ = T{0};
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
return 2;
else
return 1;
}();
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
: p_data_{p_data}, element_space_size_{element_space_size}
{
......@@ -76,15 +83,6 @@ struct DynamicBuffer
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
return 2;
else
return 1;
}();
// static_assert(element_space_size_ % PackedSize == 0, "");
if constexpr(InvalidElementUseNumericalZeroValue)
{
return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
......@@ -203,7 +201,7 @@ struct DynamicBuffer
dst_buf.p_data_,
dst_offset,
is_valid_element,
element_space_size_);
element_space_size_ / PackedSize);
}
template <typename X,
......@@ -237,7 +235,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
x, p_data_, i, is_valid_element, element_space_size_);
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
......@@ -389,7 +387,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_);
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else
{
......@@ -428,7 +426,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_);
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else if(is_valid_element)
{
......
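The DynamicBuffer hunks above hoist PackedSize to a member so that the extents handed to amd_buffer_store / amd_buffer_atomic_add / amd_buffer_atomic_max are scaled the same way as the load path: a pk_i4_t byte carries two logical 4-bit elements, so a logical element count is divided by 2 before reaching the buffer intrinsics. A small standalone sketch of that rule follows, assuming element_space_size_ counts logical (unpacked) elements; PackedI4 is a hypothetical stand-in for ck::pk_i4_t.

```cpp
// Sketch of the PackedSize scaling added to DynamicBuffer, using a stand-in
// PackedI4 type; ck::pk_i4_t plays this role in the real code.
#include <cstddef>
#include <cstdio>
#include <type_traits>

struct PackedI4 { unsigned char data; }; // two 4-bit elements per byte

template <typename T>
constexpr int packed_size()
{
    // Mirrors the constexpr lambda now stored as DynamicBuffer::PackedSize:
    // 2 for the packed int4 type, 1 for everything else.
    if constexpr(std::is_same_v<T, PackedI4>)
        return 2;
    else
        return 1;
}

// Storage-element count (what the buffer intrinsics address) backing a given
// number of logical elements.
template <typename T>
constexpr std::size_t storage_elements(std::size_t logical_elements)
{
    return logical_elements / packed_size<T>();
}

int main()
{
    static_assert(storage_elements<PackedI4>(4096) == 2048, "two nibbles per byte");
    static_assert(storage_elements<float>(4096) == 4096, "unpacked types unchanged");
    std::printf("PackedI4: %zu storage elements for 4096 logical elements\n",
                storage_elements<PackedI4>(4096));
    return 0;
}
```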
......@@ -74,6 +74,17 @@ struct ReferenceGemm : public device::BaseOperator
{
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
}
else if constexpr(is_same_v<ADataType, pk_i4_t>)
{
pk_i4_t i4x2 = arg.a_m_k_(m, k);
int8_t i4 = 0;
if(k % 2 == 0)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
v_a = type_convert<ComputeTypeA>(i4);
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
......@@ -88,12 +99,12 @@ struct ReferenceGemm : public device::BaseOperator
{
pk_i4_t i4x2 = arg.b_k_n_(k, n);
int8_t i4 = 0;
if(k % 2 == 1)
if(k % 2 == 0)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
arg.b_element_op_(v_b, i4);
v_b = type_convert<ComputeTypeB>(i4);
}
else
{
......
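In the reference GEMM the unpacking now takes the low nibble for even k and the high nibble for odd k, and converts the recentred value directly with type_convert instead of routing it through the element op. The sketch below restates that convention in isolation; a plain uint8_t stands in for pk_i4_t, and the even-k-in-low-nibble layout is the assumption this change encodes, not something guaranteed elsewhere in this diff.

```cpp
// Sketch of the pk_i4_t nibble convention assumed by the reference GEMM:
// element k sits in the low nibble when k is even, in the high nibble when k
// is odd, stored as an unsigned 4-bit code with zero-point 8.
#include <cassert>
#include <cstdint>

// Pack two signed int4 values (for k and k + 1 along K) into one byte.
static uint8_t pack_i4x2(int even_k_value, int odd_k_value)
{
    const uint8_t lo = static_cast<uint8_t>(even_k_value + 8) & 0x0f;
    const uint8_t hi = static_cast<uint8_t>(odd_k_value + 8) & 0x0f;
    return static_cast<uint8_t>(lo | (hi << 4));
}

// Mirror of the unpacking in ReferenceGemm after this change.
static int unpack_i4(uint8_t i4x2, int k)
{
    int i4 = 0;
    if(k % 2 == 0)
        i4 = i4x2 & 0x0f;        // even k: low nibble
    else
        i4 = (i4x2 >> 4) & 0x0f; // odd k: high nibble
    return i4 - 8;               // remove the zero-point
}

int main()
{
    for(int even = -8; even <= 7; ++even)
        for(int odd = -8; odd <= 7; ++odd)
        {
            const uint8_t byte = pack_i4x2(even, odd);
            assert(unpack_i4(byte, 0) == even);
            assert(unpack_i4(byte, 1) == odd);
        }
    return 0;
}
```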