Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bfcac2f9
Commit
bfcac2f9
authored
Nov 14, 2024
by
Jing Zhang
Browse files
fixed
parent
b0b67669
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
3 deletions
+3
-3
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
+1
-1
profiler/include/profiler/profile_gemm_universal_impl.hpp
profiler/include/profiler/profile_gemm_universal_impl.hpp
+2
-2
No files found.
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
View file @
bfcac2f9
...
@@ -147,7 +147,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -147,7 +147,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
()
/
2
);
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
// weight permute
// weight permute
...
...
profiler/include/profiler/profile_gemm_universal_impl.hpp
View file @
bfcac2f9
...
@@ -69,7 +69,7 @@ bool profile_gemm_universal_impl(int do_verification,
...
@@ -69,7 +69,7 @@ bool profile_gemm_universal_impl(int do_verification,
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
in
t
total_gemm_needed
=
a_m_k
.
GetElementSpaceSizeInBytes
()
+
b_k_n
.
GetElementSpaceSizeInBytes
();
std
::
size_
t
total_gemm_needed
=
a_m_k
.
GetElementSpaceSizeInBytes
()
+
b_k_n
.
GetElementSpaceSizeInBytes
();
int
rotating_count
=
std
::
max
(
int
rotating_count
=
std
::
max
(
1
,
1
,
std
::
min
(
n_iter
,
std
::
min
(
n_iter
,
...
@@ -105,7 +105,7 @@ bool profile_gemm_universal_impl(int do_verification,
...
@@ -105,7 +105,7 @@ bool profile_gemm_universal_impl(int do_verification,
const
auto
c_element_op
=
CElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
()
/
(
is_same_v
<
BDataType
,
ck
::
pk_i4_t
>
?
2
:
1
)
);
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment