Commit 65cfb2a1 authored by Jing Zhang's avatar Jing Zhang
Browse files

format

parent 398f8851
......@@ -55,8 +55,8 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
#else
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
int x_l = (x_u8 & 0x0f);
int x_h = (x_u8 & 0xf0) << 12;
int x_l = (x_u8 & 0x0f);
int x_h = (x_u8 & 0xf0) << 12;
const int EX = 0x64006400;
......@@ -66,7 +66,6 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
#endif
}
struct PassThroughPack8
......@@ -87,12 +86,16 @@ struct PassThroughPack8
vector_type<half_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};
dst.template AsType<half2_t>()(Number<0>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<half2_t>()(Number<1>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<half2_t>()(Number<2>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<half2_t>()(Number<3>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
dst.template AsType<half2_t>()(Number<0>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<half2_t>()(Number<1>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<half2_t>()(Number<2>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<half2_t>()(Number<3>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<half8_t>()[Number<0>{}];
y = dst.template AsType<half8_t>()[Number<0>{}];
#endif
}
......
......@@ -1370,7 +1370,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
......
......@@ -1025,8 +1025,7 @@ struct ThreadwiseTensorSliceTransfer_v4
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0,
"pk data N cannot be 1");
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
}
......@@ -1126,8 +1125,9 @@ struct ThreadwiseTensorSliceTransfer_v4
using src_vector_t = typename decltype(src_tmp_vector)::type;
//const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
//src_desc, src_data_coord);
// const bool is_src_valid =
// coordinate_has_valid_offset_assuming_visible_index_is_valid( src_desc,
// src_data_coord);
const bool is_src_valid = true;
// copy data from src_buf into src_tmp_vector
......
......@@ -80,14 +80,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
"SrcData != DstData");
static_assert(SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1");
static_assert(SrcScalarPerVector_ % PackedSize == 0 &&
DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1");
static_assert(
SrcVectorDim == DstVectorDim,
"pk_i4_t does not support transpose");
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
}
}
......@@ -446,7 +445,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
else
{
constexpr auto packed_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
......@@ -875,8 +874,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
private:
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
//static constexpr auto src_oob_thread_scratch_desc_ =
//decltype(GetSrcThreadScratchDescriptor()){};
// static constexpr auto src_oob_thread_scratch_desc_ =
// decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
using SrcThreadScratch =
......
......@@ -82,7 +82,7 @@ struct ReferenceGemm : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
i4 = i4 - 8;
v_a = type_convert<ComputeTypeA>(i4);
}
else
......@@ -103,7 +103,7 @@ struct ReferenceGemm : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
i4 = i4 - 8;
v_b = type_convert<ComputeTypeB>(i4);
}
else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment