Commit 0fd1d636 authored by Jing Zhang

fake conversion

Bypass the f8 -> f16 decode in type_convert<half_t, f8_t> with a plain static_cast, convert elements directly in ThreadwiseTensorSliceTransfer_v4, plumb ComputeTypeA/ComputeTypeB through the XDL blockwise GEMM selector, disable stochastic-rounding f8 conversion by default, and restrict the split-K example, instances, profiler, and build script to the path under test.

parent d70f3a34
......@@ -66,7 +66,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
- a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+ a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
......
......@@ -21,6 +21,7 @@
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
......@@ -30,7 +31,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
- using BDataType = F16;
+ using BDataType = F8;
using AccDataType = F32;
using CDataType = F16;
......
......@@ -135,7 +135,7 @@
#define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1
// set stochastic rounding as default for f8 conversions
- #define CK_USE_SR_F8_CONVERSION 1
+ #define CK_USE_SR_F8_CONVERSION 0
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
......
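
The toggled flag above switches the default f8 conversion from stochastic rounding back to the deterministic rounding mode. As a minimal illustration of the distinction (plain C++, not CK code; the grid-based quantizer and helper names are invented for this sketch, with rounding to a uniform grid of spacing `step` standing in for narrowing to a low-precision format):

```cpp
#include <cmath>
#include <cstdlib>

// Round-to-nearest: always snaps to the closest representable value.
float round_nearest(float x, float step) { return step * std::nearbyint(x / step); }

// Stochastic rounding: rounds up with probability equal to the fractional
// remainder, so the expected value of the rounded result equals x.
float round_stochastic(float x, float step)
{
    float q    = x / step;
    float lo   = std::floor(q);
    float frac = q - lo;
    float r    = static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX);
    return step * (r < frac ? lo + 1.0f : lo);
}
```
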
......@@ -401,9 +401,9 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
- index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS,
typename ComputeTypeA = FloatA,
- typename ComputeTypeB = ComputeTypeA>
+ typename ComputeTypeB = ComputeTypeA,
+ index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
: public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatA,
......@@ -415,7 +415,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
NPerXDL,
MRepeat,
NRepeat,
- KPack>
+ KPack,
+ ComputeTypeA,
+ ComputeTypeB>
{
using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatA,
......@@ -427,7 +429,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
NPerXDL,
MRepeat,
NRepeat,
- KPack>;
+ KPack,
+ ComputeTypeA,
+ ComputeTypeB>;
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
using Base::a_block_desc_m0_m1_m2_k;
......@@ -591,7 +595,9 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
- LoopScheduler LoopSched>
+ LoopScheduler LoopSched,
+ typename ComputeTypeA = FloatA,
+ typename ComputeTypeB = ComputeTypeA>
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
......@@ -606,7 +612,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
NPerXDL,
MRepeat,
NRepeat,
- KPack>{};
+ KPack,
+ ComputeTypeA,
+ ComputeTypeB>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
......@@ -620,7 +628,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
NPerXDL,
MRepeat,
NRepeat,
- KPack>{};
+ KPack,
+ ComputeTypeA,
+ ComputeTypeB>{};
}
};
......
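
The hunks above forward ComputeTypeA/ComputeTypeB into the blockwise GEMM base class and the LoopScheduler selector, and move the defaulted NumMacClusters parameter behind the compute-type parameters. Template arguments bind positionally, so without the reorder the selector's new `..., KPack, ComputeTypeA, ComputeTypeB` argument list would land on the NumMacClusters slot. A minimal sketch of that ordering effect (generic C++, not CK code; the struct names are made up for illustration):

```cpp
// Old order: the defaulted integer sits before the compute type, so overriding
// the compute type forces the caller to spell out the integer default as well.
template <typename FloatA, int NumMacClusters = 2, typename ComputeTypeA = FloatA>
struct OldOrder {};

// New order: compute types come first, NumMacClusters keeps its default.
template <typename FloatA, typename ComputeTypeA = FloatA, int NumMacClusters = 2>
struct NewOrder {};

using A = OldOrder<float, 2, double>; // must repeat the "2" just to reach ComputeTypeA
using B = NewOrder<float, double>;    // compute type can be overridden directly
```
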
......@@ -825,7 +825,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
MRepeat,
NRepeat,
K1,
- LoopSched>();
+ LoopSched,
+ ComputeType>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
......
......@@ -1160,6 +1160,8 @@ struct ThreadwiseTensorSliceTransfer_v4
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
});
}
+ #if 0
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
......@@ -1169,13 +1171,15 @@ struct ThreadwiseTensorSliceTransfer_v4
dst_tmp_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
});
+ #endif
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
- dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
+ dst_buf(Number<dst_offset>{}) =
+     type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
});
});
}
......
......@@ -365,7 +365,8 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
- return type_convert<half_t>(type_convert<float>(x));
+ // return type_convert<half_t>(type_convert<float>(x));
+ return static_cast<half_t>(x);
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
......
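
The change above replaces the documented decode path (f8 -> float -> half) with a plain static_cast, which matches the commit title: if f8_t behaves like a byte-sized integer storage type, static_cast converts the stored integer value rather than decoding the fp8 bit pattern. A standalone sketch of that distinction (plain C++, not CK code; uint8_t stands in for the f8 storage byte, and the decoder assumes an E4M3-style layout with bias 7, which may differ from the format CK actually uses):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode an 8-bit float with 4 exponent bits, 3 mantissa bits, bias 7
// (E4M3-style; NaN handling omitted for brevity).
float decode_f8(uint8_t bits)
{
    const int   sign = bits >> 7;
    const int   exp  = (bits >> 3) & 0xF;
    const float man  = static_cast<float>(bits & 0x7) / 8.0f;
    const float mag  = (exp == 0) ? std::ldexp(man, -6)             // subnormal
                                  : std::ldexp(1.0f + man, exp - 7); // normal
    return sign ? -mag : mag;
}

int main()
{
    const uint8_t x = 0x40;                                      // 2.0 in this format
    std::printf("decoded  : %f\n", decode_f8(x));                 // 2.0
    std::printf("int cast : %f\n", static_cast<float>(x));        // 64.0 -- the "fake" path
    return 0;
}
```
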
......@@ -16,9 +16,10 @@ list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_in
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
- device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
- device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
- device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
- device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp)
+ #device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
+ #device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
+ #device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
+ #device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
+ )
add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES})
......@@ -200,8 +200,8 @@ bool profile_gemm_splitk_impl(int do_verification,
std::string op_name = op_ptr->GetTypeString();
- float ave_time =
-     invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+ float ave_time = invoker_ptr->Run(argument_ptr.get(),
+                                   StreamConfig{nullptr, time_kernel, 0, 50, 200});
std::size_t flop = std::size_t(2) * M * N * K;
......
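
The rewritten Run call passes three extra StreamConfig fields. Assuming the field order matches CK's StreamConfig (stream, time_kernel, log level, warm-up iterations, timed repeats), this pins the benchmark to 50 warm-up launches and 200 timed launches instead of the defaults. A sketch of how such an aggregate initialization maps onto the fields (a stand-in struct, not CK's actual definition; names and defaults are assumptions):

```cpp
// Stand-in for the benchmarking config; field names/defaults are assumptions.
struct StreamConfigLike
{
    void* stream_id_   = nullptr; // HIP stream to launch on
    bool  time_kernel_ = false;   // whether to time the kernel at all
    int   log_level_   = 0;       // verbosity
    int   cold_niters_ = 5;       // warm-up launches before timing
    int   nrepeat_     = 50;      // timed launches averaged into ave_time
};

// Mirrors StreamConfig{nullptr, time_kernel, 0, 50, 200} in the hunk above:
// 50 warm-up runs, then 200 timed repeats.
const StreamConfigLike cfg{nullptr, true, 0, 50, 200};
```
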
......@@ -187,6 +187,7 @@ int profile_gemm_splitk(int argc, char* argv[])
{
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
}
+ #if 0
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F8{});
......@@ -203,6 +204,7 @@ int profile_gemm_splitk(int argc, char* argv[])
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F8{});
}
+ #endif
#endif
else
{
......
......@@ -8,10 +8,10 @@ MY_PROJECT_SOURCE=$1
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_CXX_FLAGS="--save-temps -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
-D GPU_TARGETS="gfx942" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE}