Commit 3a84f68e authored by Jing Zhang

fixed: add GetKPerBlock() to DeviceGemmV2, remove leftover debug static_asserts and #if 1/#endif guards, support uint8_t in amd_buffer_load_impl with matching u8 vector aliases, and drop the internal PackedSize scaling from DynamicBuffer

parent bf545630
@@ -37,6 +37,7 @@ struct DeviceGemmV2 : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual bool GetPermuteA() = 0;
virtual bool GetPermuteB() = 0;
+virtual ck::index_t GetKPerBlock() = 0;
};
......
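The new `GetKPerBlock()` pure virtual lets callers query a kernel instance's K-tile size through the type-erased base interface, e.g. to validate a K split before dispatch. A minimal standalone sketch (names and hierarchy simplified, not the real CK types):

```cpp
#include <cstdint>

using index_t = int32_t;

// Simplified stand-in for the DeviceGemmV2 interface above.
struct DeviceGemmV2Base
{
    virtual ~DeviceGemmV2Base()    = default;
    virtual index_t GetKPerBlock() = 0;
};

// A concrete device op forwards its compile-time tile parameter, so a
// runtime dispatcher can check e.g. (K % op.GetKPerBlock() == 0).
template <index_t KPerBlock>
struct DeviceGemmV2Impl : DeviceGemmV2Base
{
    index_t GetKPerBlock() override { return KPerBlock; }
};
```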
@@ -410,7 +410,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
// Pre-shuffled Weight
// BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
constexpr index_t BK01 = KPerBlock / BK1Value;
-// const index_t BK00 = BK0 / BK01;
const index_t BK0_ = StrideB / BK1Value;
const index_t BK00 = BK0_ / BK01;
......
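Worked example of the decomposition in this hunk: with the pre-shuffled layout, `StrideB / BK1Value` recovers the total number of K1-wide vectors along K, and dividing by `BK01` (vectors per K-tile) yields the tile count. Illustrative numbers only; the assumption that `StrideB` carries the full K extent follows the comment above but is not confirmed here:

```cpp
constexpr int K         = 256; // illustrative
constexpr int KPerBlock = 64;
constexpr int BK1Value  = 8;
constexpr int StrideB   = K;   // assumption: StrideB spans K in this layout

constexpr int BK01 = KPerBlock / BK1Value; // 8  K1-vectors per K-tile
constexpr int BK0_ = StrideB / BK1Value;   // 32 K1-vectors along all of K
constexpr int BK00 = BK0_ / BK01;          // 4  K-tiles, i.e. K / KPerBlock

static_assert(BK00 * BK01 * BK1Value == K,
              "K decomposes as BGlobal[BK00, N, BK01, BK1] suggests");
```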
@@ -1137,7 +1137,6 @@ struct ThreadwiseTensorSliceTransfer_v4
}
else if constexpr(SrcBuffer::IsStaticBuffer())
{
-static_assert(false, "");
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
......
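With the placeholder `static_assert(false, "")` gone, the static-buffer source path above is usable; it unrolls an element-wise copy via `static_for` with compile-time offsets. A standalone analog of that unrolling idiom (not CK's actual `static_for`, which also takes a step parameter):

```cpp
#include <utility>

// Minimal static_for analog: calls f with a compile-time index 0..N-1,
// so the loop body is fully unrolled and offsets stay constexpr.
template <int N, typename F, int... Is>
constexpr void static_for_impl(F&& f, std::integer_sequence<int, Is...>)
{
    (f(std::integral_constant<int, Is>{}), ...);
}

template <int N, typename F>
constexpr void static_for(F&& f)
{
    static_for_impl<N>(std::forward<F>(f), std::make_integer_sequence<int, N>{});
}

// Usage: per-element copy, each index known at compile time.
void copy4(int (&dst)[4], const int (&src)[4])
{
    static_for<4>([&](auto i) { dst[i] = src[i]; });
}
```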
@@ -82,9 +82,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
-static_assert(SrcScalarPerVector_ % PackedSize == 0 &&
-              DstScalarPerVector_ % PackedSize == 0,
-              "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1");
+static_assert(
+    SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
+    "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
}
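For context on the tightened assert: `PackedSize` is 2 for `pk_i4_t` (two 4-bit values share one byte) and 1 otherwise, so requiring `ScalarPerVector % PackedSize == 0` prevents a transfer from splitting a packed byte. A self-contained sketch of that rule, mirroring the lambda this commit deletes from DynamicBuffer further down:

```cpp
#include <cstdint>
#include <type_traits>

struct pk_i4_t { int8_t data; }; // stand-in for CK's packed-int4 type

// 2 for pk_i4_t (two int4 per byte), 1 for ordinary element types.
template <typename T>
constexpr int PackedSize =
    std::is_same_v<std::remove_cv_t<T>, pk_i4_t> ? 2 : 1;

template <typename T, int ScalarPerVector>
constexpr void check_vector_width()
{
    // An odd element count would end mid-byte for pk_i4_t.
    static_assert(ScalarPerVector % PackedSize<T> == 0,
                  "ScalarPerVector must be a multiple of PackedSize");
}
```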
@@ -234,8 +234,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
-static_assert(elem_op_vec_len == 1, "elem_op_vec_len != 1");
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
@@ -300,13 +298,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
-static_assert(false, "");
static_ford<SliceLengths>{}([&](auto idx) {
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
});
#else
-#if 1
// OOB Check
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
@@ -369,7 +364,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
});
-#endif
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
@@ -381,9 +375,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
-// static_assert(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
-//"transpose is not allowed for pk_i4_t");
-#if 1
+static_assert(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
+              "transpose is not allowed for pk_i4_t");
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
@@ -441,7 +434,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
-#endif
}
else
{
......
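What `transpose_vectors` does in the branch above, per the comment in the hunk: it consumes `DstScalarPerVector` source vectors and produces `SrcScalarPerVector` destination vectors, i.e. a byte-level matrix transpose between the two scratch layouts. A standalone illustration with example sizes (CK's version works on register vector references and is specialized per data type):

```cpp
#include <cstdint>

constexpr int SrcScalarPerVector = 4; // bytes per source vector (example)
constexpr int DstScalarPerVector = 4; // bytes per destination vector (example)

// DstScalarPerVector src vectors of SrcScalarPerVector bytes become
// SrcScalarPerVector dst vectors of DstScalarPerVector bytes.
void transpose_bytes(const uint8_t (&src)[DstScalarPerVector][SrcScalarPerVector],
                     uint8_t (&dst)[SrcScalarPerVector][DstScalarPerVector])
{
    for (int i = 0; i < DstScalarPerVector; ++i)
        for (int j = 0; j < SrcScalarPerVector; ++j)
            dst[j][i] = src[i][j];
}
```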
@@ -429,6 +429,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+(is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
......
@@ -1893,6 +1893,14 @@ using bf8x32_t = bf8x32_fnuz_t;
using bf8x64_t = bf8x64_fnuz_t;
#endif
+// u8
+using uint8x2_t = typename vector_type<uint8_t, 2>::type;
+using uint8x4_t = typename vector_type<uint8_t, 4>::type;
+using uint8x8_t = typename vector_type<uint8_t, 8>::type;
+using uint8x16_t = typename vector_type<uint8_t, 16>::type;
+using uint8x32_t = typename vector_type<uint8_t, 32>::type;
+using uint8x64_t = typename vector_type<uint8_t, 64>::type;
// pack int4
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
......
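The new `u8` aliases complete the family that `int8_t` and the f8 types already have; for native scalars, `vector_type<T, N>::type` resolves to a clang extended vector, so the aliases behave like the sketch below (written directly with `ext_vector_type` to stay self-contained):

```cpp
#include <cstdint>

// Equivalent shape of the new alias, e.g. uint8x4_t:
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));

uint8x4_t add_one(uint8x4_t v)
{
    // Scalar operands splat across all four lanes.
    return v + static_cast<uint8_t>(1);
}
```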
@@ -29,13 +29,6 @@ struct DynamicBuffer
ElementSpaceSize element_space_size_;
T invalid_element_value_ = T{0};
-static constexpr index_t PackedSize = []() {
-    if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
-        return 2;
-    else
-        return 1;
-}();
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
: p_data_{p_data}, element_space_size_{element_space_size}
{
@@ -59,7 +52,11 @@ struct DynamicBuffer
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
-template <typename X>
+template <typename X,
+          typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
+                                     typename scalar_type<remove_cvref_t<T>>::type>::value ||
+                                 !is_native_type<X>(),
+                             bool>::type = false>
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
{
// X contains multiple T
@@ -85,18 +82,14 @@ struct DynamicBuffer
return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
t_per_x,
coherence>(
-p_data_, i, is_valid_element, element_space_size_ / PackedSize);
+p_data_, i, is_valid_element, element_space_size_);
}
else
{
return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
t_per_x,
coherence>(
-p_data_,
-i,
-is_valid_element,
-element_space_size_ / PackedSize,
-invalid_element_value_);
+p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
}
}
else
@@ -198,10 +191,14 @@ struct DynamicBuffer
dst_buf.p_data_,
dst_offset,
is_valid_element,
-element_space_size_ / PackedSize);
+element_space_size_);
}
-template <typename X>
+template <typename X,
+          typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
+                                     typename scalar_type<remove_cvref_t<T>>::type>::value ||
+                                 !is_native_type<X>(),
+                             bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
@@ -229,7 +226,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
-x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
+x, p_data_, i, is_valid_element, element_space_size_);
}
else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
@@ -381,7 +378,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
-x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
+x, p_data_, i, is_valid_element, element_space_size_);
}
else
{
@@ -420,7 +417,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
-x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
+x, p_data_, i, is_valid_element, element_space_size_);
}
else if(is_valid_element)
{
......
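Net effect of the DynamicBuffer hunks: the buffer no longer divides `element_space_size_` by `PackedSize`, so for `pk_i4_t` the size must already be expressed in packed elements when the buffer is created. A hedged caller-side sketch; where the adjustment actually lives after this commit is not shown in this diff:

```cpp
#include "ck/ck.hpp" // assumed include; provides make_dynamic_buffer et al.

// Hypothetical helper (names are illustrative): a tensor logically holding
// n_int4_values is sized in packed pk_i4_t elements up front, since
// DynamicBuffer no longer rescales internally.
auto make_b_buffer(ck::pk_i4_t* p_b_grid, ck::index_t n_int4_values)
{
    constexpr ck::index_t PackedSize = 2; // two int4 per pk_i4_t
    return ck::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
        p_b_grid, n_int4_values / PackedSize);
}
```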