Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
66873f3a
Commit
66873f3a
authored
Nov 13, 2024
by
Jing Zhang
Browse files
clean
parent
23bdf72a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
60 deletions
+9
-60
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
+1
-52
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
+4
-2
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+4
-6
No files found.
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
View file @
66873f3a
...
...
@@ -158,7 +158,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
()
/
2
);
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
// weight permute
...
...
@@ -190,57 +190,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}
#if 0
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int input[8];
for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n_permute(j + k * 2, i);
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 0, i) = i4x2;
}
{
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 2, i) = i4x2;
}
{
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 4, i) = i4x2;
}
{
int hi = input[7];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 6, i) = i4x2;
}
}
}
#endif
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n_permute
.
mData
.
data
());
DeviceMem
workspace
;
...
...
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
View file @
66873f3a
...
...
@@ -54,7 +54,7 @@ using DeviceGemmV2Instance =
2
,
32
,
32
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
#endif
ck
::
BlockGemmPipelineScheduler
::
Interwave
,
ck
::
BlockGemmPipelineVersion
::
v2
,
C
DataType
,
C
DataType
,
false
,
PermuteB
>
;
ck
::
BlockGemmPipelineScheduler
::
Interwave
,
ck
::
BlockGemmPipelineVersion
::
v2
,
A
DataType
,
A
DataType
,
false
,
PermuteB
>
;
// clang-format on
...
...
@@ -147,7 +147,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
()
/
2
);
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
// weight permute
...
...
@@ -179,6 +179,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}
#if 1
// vector pk_i4x4 permute
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
...
...
@@ -227,6 +228,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}
}
#endif
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n_permute
.
mData
.
data
());
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
66873f3a
...
...
@@ -39,14 +39,12 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
__host__
__device__
inline
half2_t
pki4_to_half2
(
pk_i4_t
q
)
{
uint8_t
x_u8
=
ck
::
bit_cast
<
uint8_t
>
(
q
);
int
x_l
=
(
x_u8
&
0x0f
);
int
x_h
=
(
x_u8
&
0xf0
)
<<
12
;
uint32_t
i4s
=
((
x_u8
&
0x0f
)
<<
16
)
|
((
x_u8
&
0xf0
)
>>
4
);
const
int
EX
=
0x64006400
;
const
int
SUB
=
0xE408E408
;
//-8
int
lo
=
(
x_l
|
x_h
)
|
EX
;
int
lo
=
i4s
|
EX
;
return
amd_assembly_pk_add_f16
(
bit_cast
<
half2_t
>
(
lo
),
bit_cast
<
half2_t
>
(
SUB
));
}
...
...
@@ -84,8 +82,8 @@ __host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
{
uint8_t
x_u8
=
ck
::
bit_cast
<
uint8_t
>
(
q
);
float
x_h
=
((
x_u8
&
0x0f
)
>>
0
)
-
8
;
float
x_l
=
((
x_u8
&
0xf0
)
>>
4
)
-
8
;
float
x_h
=
((
x_u8
&
0x0f
)
>>
0
)
-
8
.
f
;
float
x_l
=
((
x_u8
&
0xf0
)
>>
4
)
-
8
.
f
;
vector_type
<
bhalf_t
,
2
>
res
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment