Commit 14099622 authored by coderfeli's avatar coderfeli
Browse files

fix quant 8192 err & change norm_reduce class and file name

parent aef2b33c
...@@ -516,10 +516,6 @@ include_directories(BEFORE ...@@ -516,10 +516,6 @@ include_directories(BEFORE
) )
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
......
...@@ -66,7 +66,6 @@ else() ...@@ -66,7 +66,6 @@ else()
-Wunreachable-code -Wunreachable-code
-Wunused -Wunused
-Wno-reserved-identifier -Wno-reserved-identifier
-Werror
-Wno-option-ignored -Wno-option-ignored
-Wsign-compare -Wsign-compare
-Wno-extra-semi-stmt -Wno-extra-semi-stmt
......
...@@ -302,6 +302,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -302,6 +302,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host); ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host); ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host);
} }
else if(init == 3)
{
ck_tile::FillConstant<ADataType>{}(a_host);
ck_tile::FillConstant<GDataType>{}(g_host);
ck_tile::FillConstant<DDataType>{}(d_host);
ck_tile::FillConstant<AScaleDataType>{}(sa_host);
ck_tile::FillConstant<GScaleDataType>{}(sg_host);
ck_tile::FillConstant<DScaleDataType>{}(sd_host);
ck_tile::FillConstant<YSmoothScaleDataType>{}(sy_host);
ck_tile::FillConstant<TopkWeightDataType>{}(topk_weight_host);
}
// permute weight // permute weight
ck_tile::HostTensor<GDataType> g_perm_host = gate_only? shuffle_moe_weight(g_host, prec_w, 1) : shuffle_moe_weight_gateup(g_host, prec_w, 1); ck_tile::HostTensor<GDataType> g_perm_host = gate_only? shuffle_moe_weight(g_host, prec_w, 1) : shuffle_moe_weight_gateup(g_host, prec_w, 1);
...@@ -322,7 +333,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -322,7 +333,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else else
{ {
for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++) { for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++) {
topk_ids_host.mData[i] = i % 4; topk_ids_host.mData[i] = 0;
} }
// topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913); // topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
} }
...@@ -486,7 +497,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -486,7 +497,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host.savetxt("num_sorted_tiles_host.txt", "int"); num_sorted_tiles_host.savetxt("num_sorted_tiles_host.txt", "int");
auto o_dev = o_buf.ToHost<ODataType>(); auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float"); o_dev.savetxt("gpu-out.txt", "float");
auto [rtol, atol] = get_elimit<ADataType>(); auto [rtol, atol] = get_elimit<ADataType>();
pass &= ck_tile::check_err( pass &= ck_tile::check_err(
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol); o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
...@@ -595,7 +606,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -595,7 +606,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
gate_only); gate_only);
auto o_dev = o_buf.ToHost<ODataType>(); auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float"); o_dev.savetxt("gpu-out.txt", "float");
auto [rtol, atol] = get_elimit<ADataType>(); auto [rtol, atol] = get_elimit<ADataType>();
pass &= ck_tile::check_err( pass &= ck_tile::check_err(
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol); o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
......
...@@ -76,7 +76,7 @@ check_err(const Range& out, ...@@ -76,7 +76,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 32)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...@@ -136,7 +136,7 @@ check_err(const Range& out, ...@@ -136,7 +136,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 32)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...@@ -195,7 +195,7 @@ check_err(const Range& out, ...@@ -195,7 +195,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 32)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...@@ -250,7 +250,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -250,7 +250,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 32)
{ {
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
<< std::endl; << std::endl;
...@@ -327,7 +327,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -327,7 +327,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 32)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl; << "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
...@@ -381,7 +381,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -381,7 +381,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 32)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......
...@@ -339,7 +339,7 @@ struct FillStepRange ...@@ -339,7 +339,7 @@ struct FillStepRange
template <typename T> template <typename T>
struct FillConstant struct FillConstant
{ {
T value_{0}; T value_{type_convert<T>(1.0f)};
template <typename ForwardIter> template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const void operator()(ForwardIter first, ForwardIter last) const
......
...@@ -19,7 +19,7 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_Base ...@@ -19,7 +19,7 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_Base
static constexpr index_t Block_K = 256; static constexpr index_t Block_K = 256;
static constexpr index_t WarpPerBlock_M = 1; static constexpr index_t WarpPerBlock_M = 1;
static constexpr index_t WarpPerBlock_N = 4; static constexpr index_t WarpPerBlock_N = 4;
static constexpr index_t WarpPerBlock_K = 1; static constexpr index_t WarpPerBlock_K = 1;
static constexpr index_t Warp_M = 16; static constexpr index_t Warp_M = 16;
......
...@@ -85,6 +85,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_ ...@@ -85,6 +85,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
register float v_c29 asm("v93"); register float v_c29 asm("v93");
register float v_c30 asm("v94"); register float v_c30 asm("v94");
register float v_c31 asm("v95"); register float v_c31 asm("v95");
register bf16x2_t v_debug asm("v160");
register bf16x2_t v_debug1 asm("v161");
register bf16x2_t v_debug2 asm("v162");
register bf16x2_t v_debug3 asm("v163");
register bf16x2_t v_debug4 asm("v164");
register bf16x2_t v_debug5 asm("v165");
register bf16x2_t v_debug6 asm("v166");
register bf16x2_t v_debug7 asm("v167");
int32_t nan_hi = 0x7fff0000; int32_t nan_hi = 0x7fff0000;
int32_t nan_lo = 0x00007fff; int32_t nan_lo = 0x00007fff;
...@@ -154,7 +162,15 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_ ...@@ -154,7 +162,15 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
[c28]"+v"(v_c28), [c28]"+v"(v_c28),
[c29]"+v"(v_c29), [c29]"+v"(v_c29),
[c30]"+v"(v_c30), [c30]"+v"(v_c30),
[c31]"+v"(v_c31) [c31]"+v"(v_c31),
[debug0]"+v"(v_debug),
[debug1]"+v"(v_debug1),
[debug2]"+v"(v_debug2),
[debug3]"+v"(v_debug3),
[debug4]"+v"(v_debug4),
[debug5]"+v"(v_debug5),
[debug6]"+v"(v_debug6),
[debug7]"+v"(v_debug7)
: :
[sld_a_base]"n"(0), [sld_a_base]"n"(0),
[shfl_base]"n"(0), [shfl_base]"n"(0),
...@@ -259,6 +275,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_ ...@@ -259,6 +275,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
); );
#pragma clang diagnostic pop #pragma clang diagnostic pop
// clang-format on // clang-format on
if(1) {
printf("\n%d %.1f, %.1f, %.1f, %.1f, %.1f, %.1f, %.1f, %.1f\n",
threadIdx.x,
type_convert<float>(v_debug.x), type_convert<float>(v_debug.y),
type_convert<float>(v_debug1.x), type_convert<float>(v_debug1.y),
type_convert<float>(v_debug2.x), type_convert<float>(v_debug2.y),
type_convert<float>(v_debug3.x), type_convert<float>(v_debug3.y));
}
} }
}; };
......
...@@ -111,6 +111,7 @@ ...@@ -111,6 +111,7 @@
" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:23168 \n" " ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:23168 \n"
" s_mov_b32 s80, 0 \n" " s_mov_b32 s80, 0 \n"
" s_waitcnt vmcnt(8) \n" " s_waitcnt vmcnt(8) \n"
" s_waitcnt vmcnt(0) & lgkmcnt(0) \n"
"coreloop_top_%=: \n" "coreloop_top_%=: \n"
" s_waitcnt vmcnt(0) & lgkmcnt(0) \n" " s_waitcnt vmcnt(0) & lgkmcnt(0) \n"
" s_barrier \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0 \n" " s_barrier \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0 \n"
......
...@@ -74,7 +74,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -74,7 +74,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize(); constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge = constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0; BlockShape::Block_M0 * BlockShape::Block_N0;
return max(smem_0, max(smem_1, smem_bridge)); return 32768;//max(smem_0, max(smem_1, smem_bridge));
} }
// this is the thread-offset along row/col // this is the thread-offset along row/col
...@@ -329,8 +329,15 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -329,8 +329,15 @@ struct FusedMoeGemmPipeline_FlatmmUk
BlockShape::Block_K0, // tile offset for B matrix each unroll BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 * BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll BlockShape::Block_W0); // tile offset for B matrix each unroll
// for(auto i = 0; i < 8; i++)
// {
// if(threadIdx.x==0) {
// printf("%d, %.1f, %.1f, %.1f, %.1f\n",i, acc_0_full.get_thread_buffer()[4 * (i) + 0], acc_0_full.get_thread_buffer()[4 * (i) + 1], acc_0_full.get_thread_buffer()[4 * (i) + 2], acc_0_full.get_thread_buffer()[4 * (i) + 3]);
// }
// }
// auto acc_0 = IsGateOnly ? acc_0_full : Policy::template GetUK_0<Problem>().MakeCBlockTileGUMerge(); // auto acc_0 = IsGateOnly ? acc_0_full : Policy::template GetUK_0<Problem>().MakeCBlockTileGUMerge();
auto acc_0 = Policy::template GetUK_0<Problem>().MakeCBlockTileGUMerge(); auto acc_0 = Policy::template GetUK_0<Problem>().MakeCBlockTileGUMerge();
if (!IsGateOnly) { if (!IsGateOnly) {
sweep_tile(acc_0, [&](auto idx0) { sweep_tile(acc_0, [&](auto idx0) {
acc_0(idx0) = acc_0_full(idx0); acc_0(idx0) = acc_0_full(idx0);
...@@ -359,7 +366,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -359,7 +366,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
} }
if (!IsGateOnly) { if (!IsGateOnly) {
for(auto i = 0; i < BlockShape::Repeat_N0; i++) for(auto i = 0; i < BlockShape::Repeat_N0 * BlockShape::Repeat_M0; i++)
{ {
acc_0.get_thread_buffer()[4 * i + 0] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 0]; acc_0.get_thread_buffer()[4 * i + 0] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 0];
acc_0.get_thread_buffer()[4 * i + 1] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 1]; acc_0.get_thread_buffer()[4 * i + 1] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 1];
...@@ -367,10 +374,15 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -367,10 +374,15 @@ struct FusedMoeGemmPipeline_FlatmmUk
acc_0.get_thread_buffer()[4 * i + 3] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 3]; acc_0.get_thread_buffer()[4 * i + 3] *= acc_0_full.get_thread_buffer()[4 * (i + BlockShape::Repeat_N0) + 3];
} }
} }
auto y_pre = acc_0;
block_sync_lds(); block_sync_lds();
store_tile(bridge_sst_win, cast_tile<YDataType>(y_pre)); store_tile(bridge_sst_win, cast_tile<YDataType>(acc_0));
block_sync_lds(); block_sync_lds();
// YDataType *smemy = reinterpret_cast<YDataType *>(smem);
// if(threadIdx.x==0) {
// for (int i = 0; i<32 * 256; i++) {
// printf("%.1f,", type_convert<float>(smemy[i]));
// }}
// block_sync_lds();
auto uk_1 = Policy::template GetUK_1<Problem>(); auto uk_1 = Policy::template GetUK_1<Problem>();
uk_1(d_res, uk_1(d_res,
......
...@@ -17,7 +17,7 @@ fi ...@@ -17,7 +17,7 @@ fi
cmake \ cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm/ \ -D CMAKE_PREFIX_PATH=/opt/rocm/ \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker -g -v --save-temps -Wno-gnu-line-marker " \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \ -D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \ -D GPU_TARGETS=$GPU_TARGETS \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment