Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6e7218fb
Commit
6e7218fb
authored
Jan 01, 2025
by
mtgu0705
Browse files
Fixed example_gemm_xdl_bf16_pk_i4_v3 test failed issue
parent
a454ab1d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
7 deletions
+7
-7
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
+7
-7
No files found.
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
View file @
6e7218fb
...
@@ -77,21 +77,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -77,21 +77,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
};
};
auto
f_get_default_stride
=
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size
_t
stride
,
auto
layout
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
ck
::
index
_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
if
(
stride
==
-
1
)
{
{
// give a chance if stride is
zero
, return a default packed stride
// give a chance if stride is
-1
, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
return
col
;
return
static_cast
<
std
::
size_t
>
(
col
)
;
}
}
else
else
{
{
return
row
;
return
static_cast
<
std
::
size_t
>
(
row
)
;
}
}
}
}
else
else
return
stride
;
return
static_cast
<
std
::
size_t
>
(
stride
)
;
};
};
StrideA
=
f_get_default_stride
(
M
,
K
,
StrideA
,
ALayout
{});
StrideA
=
f_get_default_stride
(
M
,
K
,
StrideA
,
ALayout
{});
...
@@ -133,7 +133,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -133,7 +133,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_permute
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
// weight permute
// weight permute
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment