Commit 2d50ecbe authored by mtgu0705's avatar mtgu0705
Browse files

fixed

parent 3eee7eda
...@@ -1146,36 +1146,7 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1146,36 +1146,7 @@ struct ThreadwiseTensorSliceTransfer_v4
}); });
} }
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value && if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
is_same<remove_cvref_t<DstData>, half_t>::value)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
constexpr index_t pack_size = 8;
static_assert(SrcScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
else if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
is_same<remove_cvref_t<DstData>, f8_t>::value)
{ {
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData) // DstData)
...@@ -1361,7 +1332,6 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1361,7 +1332,6 @@ struct ThreadwiseTensorSliceTransfer_v4
} }
else if constexpr(SrcBuffer::IsStaticBuffer()) else if constexpr(SrcBuffer::IsStaticBuffer())
{ {
static_assert(false, "");
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset( constexpr index_t src_offset = src_desc.CalculateOffset(
src_ref_to_origin_disp_idx + data_to_origin_disp_idx + src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
...@@ -1371,10 +1341,8 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1371,10 +1341,8 @@ struct ThreadwiseTensorSliceTransfer_v4
}); });
} }
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value && if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
is_same<remove_cvref_t<DstData>, half_t>::value)
{ {
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData) // DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector; vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
...@@ -1405,34 +1373,6 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1405,34 +1373,6 @@ struct ThreadwiseTensorSliceTransfer_v4
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i]; dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
}); });
} }
else if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
is_same<remove_cvref_t<DstData>, f8_t>::value)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
constexpr index_t pack_size = 8;
static_assert(SrcScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value && else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
is_same<remove_cvref_t<DstData>, half_t>::value && is_same<remove_cvref_t<DstData>, half_t>::value &&
SrcScalarPerVector % 2 == 0) SrcScalarPerVector % 2 == 0)
......
...@@ -8,18 +8,18 @@ ...@@ -8,18 +8,18 @@
#include <typeinfo> #include <typeinfo>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp" #include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck { namespace ck {
namespace profiler { namespace profiler {
...@@ -162,9 +162,9 @@ bool profile_gemm_b_scale_impl(int do_verification, ...@@ -162,9 +162,9 @@ bool profile_gemm_b_scale_impl(int do_verification,
ck::pk_i4_t i4x2 = b_k_n(k, n).data; ck::pk_i4_t i4x2 = b_k_n(k, n).data;
int8_t i4 = 0; int8_t i4 = 0;
if(k % 2 == 1) if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf; i4 = (i4x2.data >> 0) & 0xf;
else else
i4 = (i4x2 >> 4) & 0xf; i4 = (i4x2.data >> 4) & 0xf;
i4 = i4 - 8; i4 = i4 - 8;
v_b = ck::type_convert<float>(i4); v_b = ck::type_convert<float>(i4);
...@@ -218,6 +218,8 @@ bool profile_gemm_b_scale_impl(int do_verification, ...@@ -218,6 +218,8 @@ bool profile_gemm_b_scale_impl(int do_verification,
} }
} }
if(is_same_v<BDataType, pk_i4_t> && is_same_v<ADataType, half_t>)
{
// vector pk_i4x4 permute // vector pk_i4x4 permute
for(int i = 0; i < N; i++) for(int i = 0; i < N; i++)
{ {
...@@ -267,15 +269,10 @@ bool profile_gemm_b_scale_impl(int do_verification, ...@@ -267,15 +269,10 @@ bool profile_gemm_b_scale_impl(int do_verification,
} }
} }
} }
}
else else
{ {
for(int i = 0; i < N; i++) b_k_n_permute = b_k_n;
{
for(int j = 0; j < K; j++)
{
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
}
}
} }
b_device_buf.ToDevice(b_k_n_permute.mData.data()); b_device_buf.ToDevice(b_k_n_permute.mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment