Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
14099622
Commit
14099622
authored
Jan 13, 2025
by
coderfeli
Browse files
fix quant 8192 err & change norm_reduce class and file name
parent
aef2b33c
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
65 additions
and
22 deletions
+65
-22
CMakeLists.txt
CMakeLists.txt
+0
-4
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+0
-1
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+14
-3
include/ck_tile/host/check_err.hpp
include/ck_tile/host/check_err.hpp
+6
-6
include/ck_tile/host/fill.hpp
include/ck_tile/host/fill.hpp
+1
-1
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32.hpp
.../ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32.hpp
+1
-1
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
.../flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
+25
-1
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x256_1x4x1_16x16x16_itl.inc
...ck/uk/flatmm_sn_uk_gfx9_32x128x256_1x4x1_16x16x16_itl.inc
+1
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
+16
-4
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+1
-1
No files found.
CMakeLists.txt
View file @
14099622
...
@@ -516,10 +516,6 @@ include_directories(BEFORE
...
@@ -516,10 +516,6 @@ include_directories(BEFORE
)
)
SET
(
BUILD_DEV ON CACHE BOOL
"BUILD_DEV"
)
SET
(
BUILD_DEV ON CACHE BOOL
"BUILD_DEV"
)
if
(
BUILD_DEV
)
add_compile_options
(
-Werror
)
add_compile_options
(
-Weverything
)
endif
()
message
(
"CMAKE_CXX_FLAGS:
${
CMAKE_CXX_FLAGS
}
"
)
message
(
"CMAKE_CXX_FLAGS:
${
CMAKE_CXX_FLAGS
}
"
)
if
(
"
${
CMAKE_CXX_COMPILER_ID
}
"
MATCHES
"Clang"
)
if
(
"
${
CMAKE_CXX_COMPILER_ID
}
"
MATCHES
"Clang"
)
...
...
cmake/EnableCompilerWarnings.cmake
View file @
14099622
...
@@ -66,7 +66,6 @@ else()
...
@@ -66,7 +66,6 @@ else()
-Wunreachable-code
-Wunreachable-code
-Wunused
-Wunused
-Wno-reserved-identifier
-Wno-reserved-identifier
-Werror
-Wno-option-ignored
-Wno-option-ignored
-Wsign-compare
-Wsign-compare
-Wno-extra-semi-stmt
-Wno-extra-semi-stmt
...
...
example/ck_tile/15_fused_moe/main.cpp
View file @
14099622
...
@@ -302,6 +302,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -302,6 +302,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
FillNormalDistribution
<
YSmoothScaleDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillNormalDistribution
<
YSmoothScaleDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillNormalDistribution
<
TopkWeightDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
topk_weight_host
);
ck_tile
::
FillNormalDistribution
<
TopkWeightDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
topk_weight_host
);
}
}
else
if
(
init
==
3
)
{
ck_tile
::
FillConstant
<
ADataType
>
{}(
a_host
);
ck_tile
::
FillConstant
<
GDataType
>
{}(
g_host
);
ck_tile
::
FillConstant
<
DDataType
>
{}(
d_host
);
ck_tile
::
FillConstant
<
AScaleDataType
>
{}(
sa_host
);
ck_tile
::
FillConstant
<
GScaleDataType
>
{}(
sg_host
);
ck_tile
::
FillConstant
<
DScaleDataType
>
{}(
sd_host
);
ck_tile
::
FillConstant
<
YSmoothScaleDataType
>
{}(
sy_host
);
ck_tile
::
FillConstant
<
TopkWeightDataType
>
{}(
topk_weight_host
);
}
// permute weight
// permute weight
ck_tile
::
HostTensor
<
GDataType
>
g_perm_host
=
gate_only
?
shuffle_moe_weight
(
g_host
,
prec_w
,
1
)
:
shuffle_moe_weight_gateup
(
g_host
,
prec_w
,
1
);
ck_tile
::
HostTensor
<
GDataType
>
g_perm_host
=
gate_only
?
shuffle_moe_weight
(
g_host
,
prec_w
,
1
)
:
shuffle_moe_weight_gateup
(
g_host
,
prec_w
,
1
);
...
@@ -322,7 +333,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -322,7 +333,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else
else
{
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
topk_ids_host
.
mData
.
size
());
i
++
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
topk_ids_host
.
mData
.
size
());
i
++
)
{
topk_ids_host
.
mData
[
i
]
=
i
%
4
;
topk_ids_host
.
mData
[
i
]
=
0
;
}
}
// topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
// topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
}
}
...
@@ -486,7 +497,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -486,7 +497,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host
.
savetxt
(
"num_sorted_tiles_host.txt"
,
"int"
);
num_sorted_tiles_host
.
savetxt
(
"num_sorted_tiles_host.txt"
,
"int"
);
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
//
o_dev.savetxt("gpu-out.txt", "float");
o_dev
.
savetxt
(
"gpu-out.txt"
,
"float"
);
auto
[
rtol
,
atol
]
=
get_elimit
<
ADataType
>
();
auto
[
rtol
,
atol
]
=
get_elimit
<
ADataType
>
();
pass
&=
ck_tile
::
check_err
(
pass
&=
ck_tile
::
check_err
(
o_dev
,
o_host
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
o_dev
,
o_host
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
...
@@ -595,7 +606,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -595,7 +606,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
gate_only
);
gate_only
);
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
//
o_dev.savetxt("gpu-out.txt", "float");
o_dev
.
savetxt
(
"gpu-out.txt"
,
"float"
);
auto
[
rtol
,
atol
]
=
get_elimit
<
ADataType
>
();
auto
[
rtol
,
atol
]
=
get_elimit
<
ADataType
>
();
pass
&=
ck_tile
::
check_err
(
pass
&=
ck_tile
::
check_err
(
o_dev
,
o_host
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
o_dev
,
o_host
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
...
...
include/ck_tile/host/check_err.hpp
View file @
14099622
...
@@ -76,7 +76,7 @@ check_err(const Range& out,
...
@@ -76,7 +76,7 @@ check_err(const Range& out,
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
32
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
@@ -136,7 +136,7 @@ check_err(const Range& out,
...
@@ -136,7 +136,7 @@ check_err(const Range& out,
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
32
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
@@ -195,7 +195,7 @@ check_err(const Range& out,
...
@@ -195,7 +195,7 @@ check_err(const Range& out,
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
32
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
@@ -250,7 +250,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -250,7 +250,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
32
)
{
{
std
::
cerr
<<
msg
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
std
::
cerr
<<
msg
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
std
::
endl
;
...
@@ -327,7 +327,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -327,7 +327,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
32
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o_fp64
<<
" != "
<<
r_fp64
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o_fp64
<<
" != "
<<
r_fp64
<<
std
::
endl
;
...
@@ -381,7 +381,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -381,7 +381,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
32
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
...
include/ck_tile/host/fill.hpp
View file @
14099622
...
@@ -339,7 +339,7 @@ struct FillStepRange
...
@@ -339,7 +339,7 @@ struct FillStepRange
template
<
typename
T
>
template
<
typename
T
>
struct
FillConstant
struct
FillConstant
{
{
T
value_
{
0
};
T
value_
{
type_convert
<
T
>
(
1.0
f
)
};
template
<
typename
ForwardIter
>
template
<
typename
ForwardIter
>
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
...
...
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32.hpp
View file @
14099622
...
@@ -19,7 +19,7 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_Base
...
@@ -19,7 +19,7 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_Base
static
constexpr
index_t
Block_K
=
256
;
static
constexpr
index_t
Block_K
=
256
;
static
constexpr
index_t
WarpPerBlock_M
=
1
;
static
constexpr
index_t
WarpPerBlock_M
=
1
;
static
constexpr
index_t
WarpPerBlock_N
=
4
;
static
constexpr
index_t
WarpPerBlock_N
=
4
;
static
constexpr
index_t
WarpPerBlock_K
=
1
;
static
constexpr
index_t
WarpPerBlock_K
=
1
;
static
constexpr
index_t
Warp_M
=
16
;
static
constexpr
index_t
Warp_M
=
16
;
...
...
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
View file @
14099622
...
@@ -85,6 +85,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
...
@@ -85,6 +85,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
register
float
v_c29
asm
(
"v93"
);
register
float
v_c29
asm
(
"v93"
);
register
float
v_c30
asm
(
"v94"
);
register
float
v_c30
asm
(
"v94"
);
register
float
v_c31
asm
(
"v95"
);
register
float
v_c31
asm
(
"v95"
);
register
bf16x2_t
v_debug
asm
(
"v160"
);
register
bf16x2_t
v_debug1
asm
(
"v161"
);
register
bf16x2_t
v_debug2
asm
(
"v162"
);
register
bf16x2_t
v_debug3
asm
(
"v163"
);
register
bf16x2_t
v_debug4
asm
(
"v164"
);
register
bf16x2_t
v_debug5
asm
(
"v165"
);
register
bf16x2_t
v_debug6
asm
(
"v166"
);
register
bf16x2_t
v_debug7
asm
(
"v167"
);
int32_t
nan_hi
=
0x7fff0000
;
int32_t
nan_hi
=
0x7fff0000
;
int32_t
nan_lo
=
0x00007fff
;
int32_t
nan_lo
=
0x00007fff
;
...
@@ -154,7 +162,15 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
...
@@ -154,7 +162,15 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
[
c28
]
"+v"
(
v_c28
),
[
c28
]
"+v"
(
v_c28
),
[
c29
]
"+v"
(
v_c29
),
[
c29
]
"+v"
(
v_c29
),
[
c30
]
"+v"
(
v_c30
),
[
c30
]
"+v"
(
v_c30
),
[
c31
]
"+v"
(
v_c31
)
[
c31
]
"+v"
(
v_c31
),
[
debug0
]
"+v"
(
v_debug
),
[
debug1
]
"+v"
(
v_debug1
),
[
debug2
]
"+v"
(
v_debug2
),
[
debug3
]
"+v"
(
v_debug3
),
[
debug4
]
"+v"
(
v_debug4
),
[
debug5
]
"+v"
(
v_debug5
),
[
debug6
]
"+v"
(
v_debug6
),
[
debug7
]
"+v"
(
v_debug7
)
:
:
[
sld_a_base
]
"n"
(
0
),
[
sld_a_base
]
"n"
(
0
),
[
shfl_base
]
"n"
(
0
),
[
shfl_base
]
"n"
(
0
),
...
@@ -259,6 +275,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
...
@@ -259,6 +275,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
);
);
#pragma clang diagnostic pop
#pragma clang diagnostic pop
// clang-format on
// clang-format on
if
(
1
)
{
printf
(
"
\n
%d %.1f, %.1f, %.1f, %.1f, %.1f, %.1f, %.1f, %.1f
\n
"
,
threadIdx
.
x
,
type_convert
<
float
>
(
v_debug
.
x
),
type_convert
<
float
>
(
v_debug
.
y
),
type_convert
<
float
>
(
v_debug1
.
x
),
type_convert
<
float
>
(
v_debug1
.
y
),
type_convert
<
float
>
(
v_debug2
.
x
),
type_convert
<
float
>
(
v_debug2
.
y
),
type_convert
<
float
>
(
v_debug3
.
x
),
type_convert
<
float
>
(
v_debug3
.
y
));
}
}
}
};
};
...
...
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x256_1x4x1_16x16x16_itl.inc
View file @
14099622
...
@@ -111,6 +111,7 @@
...
@@ -111,6 +111,7 @@
" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:23168
\n
"
" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:23168
\n
"
" s_mov_b32 s80, 0
\n
"
" s_mov_b32 s80, 0
\n
"
" s_waitcnt vmcnt(8)
\n
"
" s_waitcnt vmcnt(8)
\n
"
" s_waitcnt vmcnt(0) & lgkmcnt(0)
\n
"
"coreloop_top_%=:
\n
"
"coreloop_top_%=:
\n
"
" s_waitcnt vmcnt(0) & lgkmcnt(0)
\n
"
" s_waitcnt vmcnt(0) & lgkmcnt(0)
\n
"
" s_barrier
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0
\n
"
" s_barrier
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0
\n
"
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
View file @
14099622
...
@@ -74,7 +74,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -74,7 +74,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
constexpr
index_t
smem_1
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_1
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_bridge
=
constexpr
index_t
smem_bridge
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
;
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
;
return
max
(
smem_0
,
max
(
smem_1
,
smem_bridge
));
return
32768
;
//
max(smem_0, max(smem_1, smem_bridge));
}
}
// this is the thread-offset along row/col
// this is the thread-offset along row/col
...
@@ -329,8 +329,15 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -329,8 +329,15 @@ struct FusedMoeGemmPipeline_FlatmmUk
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
// for(auto i = 0; i < 8; i++)
// {
// if(threadIdx.x==0) {
// printf("%d, %.1f, %.1f, %.1f, %.1f\n",i, acc_0_full.get_thread_buffer()[4 * (i) + 0], acc_0_full.get_thread_buffer()[4 * (i) + 1], acc_0_full.get_thread_buffer()[4 * (i) + 2], acc_0_full.get_thread_buffer()[4 * (i) + 3]);
// }
// }
// auto acc_0 = IsGateOnly ? acc_0_full : Policy::template GetUK_0<Problem>().MakeCBlockTileGUMerge();
// auto acc_0 = IsGateOnly ? acc_0_full : Policy::template GetUK_0<Problem>().MakeCBlockTileGUMerge();
auto
acc_0
=
Policy
::
template
GetUK_0
<
Problem
>().
MakeCBlockTileGUMerge
();
auto
acc_0
=
Policy
::
template
GetUK_0
<
Problem
>().
MakeCBlockTileGUMerge
();
if
(
!
IsGateOnly
)
{
if
(
!
IsGateOnly
)
{
sweep_tile
(
acc_0
,
[
&
](
auto
idx0
)
{
sweep_tile
(
acc_0
,
[
&
](
auto
idx0
)
{
acc_0
(
idx0
)
=
acc_0_full
(
idx0
);
acc_0
(
idx0
)
=
acc_0_full
(
idx0
);
...
@@ -359,7 +366,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -359,7 +366,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
}
}
if
(
!
IsGateOnly
)
{
if
(
!
IsGateOnly
)
{
for
(
auto
i
=
0
;
i
<
BlockShape
::
Repeat_N0
;
i
++
)
for
(
auto
i
=
0
;
i
<
BlockShape
::
Repeat_N0
*
BlockShape
::
Repeat_M0
;
i
++
)
{
{
acc_0
.
get_thread_buffer
()[
4
*
i
+
0
]
*=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
0
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
0
]
*=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
0
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
1
]
*=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
1
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
1
]
*=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
1
];
...
@@ -367,10 +374,15 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -367,10 +374,15 @@ struct FusedMoeGemmPipeline_FlatmmUk
acc_0
.
get_thread_buffer
()[
4
*
i
+
3
]
*=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
3
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
3
]
*=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
3
];
}
}
}
}
auto
y_pre
=
acc_0
;
block_sync_lds
();
block_sync_lds
();
store_tile
(
bridge_sst_win
,
cast_tile
<
YDataType
>
(
y_pre
));
store_tile
(
bridge_sst_win
,
cast_tile
<
YDataType
>
(
acc_0
));
block_sync_lds
();
block_sync_lds
();
// YDataType *smemy = reinterpret_cast<YDataType *>(smem);
// if(threadIdx.x==0) {
// for (int i = 0; i<32 * 256; i++) {
// printf("%.1f,", type_convert<float>(smemy[i]));
// }}
// block_sync_lds();
auto
uk_1
=
Policy
::
template
GetUK_1
<
Problem
>();
auto
uk_1
=
Policy
::
template
GetUK_1
<
Problem
>();
uk_1
(
d_res
,
uk_1
(
d_res
,
...
...
script/cmake-ck-dev.sh
View file @
14099622
...
@@ -17,7 +17,7 @@ fi
...
@@ -17,7 +17,7 @@ fi
cmake
\
cmake
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm/
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm/
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_FLAGS
=
"-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker"
\
-D
CMAKE_CXX_FLAGS
=
"-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker
-g -v --save-temps -Wno-gnu-line-marker
"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
BUILD_DEV
=
ON
\
-D
GPU_TARGETS
=
$GPU_TARGETS
\
-D
GPU_TARGETS
=
$GPU_TARGETS
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment