Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6a2521ea
Commit
6a2521ea
authored
Oct 23, 2024
by
Jing Zhang
Browse files
fixed splitk crash
parent
af2c0166
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
11 deletions
+12
-11
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+10
-7
include/ck/utility/amd_inline_asm.hpp
include/ck/utility/amd_inline_asm.hpp
+1
-3
No files found.
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
View file @
6a2521ea
...
@@ -147,7 +147,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -147,7 +147,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
_permute
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
// weight permute
// weight permute
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
6a2521ea
...
@@ -389,7 +389,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -389,7 +389,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}
else
else
{
{
// Weight Tile Permute
#ifndef WEIGHT_PERMUTE
#ifndef WEIGHT_PERMUTE
// not pad N or K
// not pad N or K
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
...
@@ -398,23 +397,27 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -398,23 +397,27 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_pass_through_transform
(
N
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
#else
#else
// Weight Tile Permute
constexpr
index_t
BK01
=
KPerBlock
/
BK1Value
;
constexpr
index_t
BK01
=
KPerBlock
/
BK1Value
;
const
index_t
BK00
=
BK0
/
BK01
;
const
index_t
BK0_
=
StrideB
/
BK1Value
;
const
index_t
BK00
=
BK0_
/
BK01
;
const
auto
b_grid_desc_bk00_n_bk01_bk1
=
const
auto
b_grid_desc_bk00_n_bk01_bk1
_permute
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
BK00
,
N
,
BK01
,
BK1Value
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
BK00
,
N
,
BK01
,
BK1Value
));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
const
auto
b_grid_desc_bk0_n_bk1
_permute
=
transform_tensor_descriptor
(
b_grid_desc_bk00_n_bk01_bk1
,
b_grid_desc_bk00_n_bk01_bk1
_permute
,
make_tuple
(
make_merge_transform
(
make_tuple
(
BK00
,
BK01
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
BK00
,
BK01
)),
make_pass_through_transform
(
make_tuple
(
N
)),
make_pass_through_transform
(
make_tuple
(
N
)),
make_pass_through_transform
(
BK1Value
)),
make_pass_through_transform
(
BK1Value
)),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
#endif
return
b_grid_desc_bk0_n_bk1
;
return
b_grid_desc_bk0_n_bk1_permute
;
#endif
}
}
}
}
...
...
include/ck/utility/amd_inline_asm.hpp
View file @
6a2521ea
...
@@ -14,9 +14,7 @@ namespace ck {
...
@@ -14,9 +14,7 @@ namespace ck {
inline
__device__
int
amd_assembly_and_or_b32
(
int
a
,
int
b
,
int
d
)
inline
__device__
int
amd_assembly_and_or_b32
(
int
a
,
int
b
,
int
d
)
{
{
int
c
;
int
c
;
asm
volatile
(
"v_and_or_b32 %0, %1, %2, %3"
asm
volatile
(
"v_and_or_b32 %0, %1, %2, %3"
:
"=v"
(
c
)
:
"v"
(
a
),
"v"
(
b
),
"v"
(
d
));
:
"=v"
(
c
)
:
"v"
(
a
),
"v"
(
b
),
"v"
(
d
));
return
c
;
return
c
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment