Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a6ccd2ec
Commit
a6ccd2ec
authored
Oct 28, 2024
by
Jing Zhang
Browse files
fixed int4 to bhalf_t conversion
parent
d642ce41
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
83 additions
and
57 deletions
+83
-57
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
+2
-2
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+9
-7
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+5
-5
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp
.../library/tensor_operation_instance/gpu/gemm_universal.hpp
+15
-0
library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt
...nsor_operation_instance/gpu/gemm_universal/CMakeLists.txt
+1
-0
profiler/include/profiler/profile_gemm_universal_impl.hpp
profiler/include/profiler/profile_gemm_universal_impl.hpp
+45
-42
profiler/src/profile_gemm_universal.cpp
profiler/src/profile_gemm_universal.cpp
+6
-1
No files found.
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
View file @
a6ccd2ec
...
...
@@ -65,7 +65,7 @@ using DeviceGemmV2Instance =
2
,
32
,
32
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
#endif
ck
::
BlockGemmPipelineScheduler
::
Interwave
,
ck
::
BlockGemmPipelineVersion
::
v
1
,
CDataType
,
CDataType
,
false
,
PermuteB
>
;
ck
::
BlockGemmPipelineScheduler
::
Interwave
,
ck
::
BlockGemmPipelineVersion
::
v
2
,
CDataType
,
CDataType
,
false
,
PermuteB
>
;
// clang-format on
...
...
@@ -146,7 +146,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
BDataType
>
{
1
});
break
;
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0
.0
,
1.0
});
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0
,
1.0
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
2
,
2
});
}
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
a6ccd2ec
...
...
@@ -51,10 +51,11 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
return
amd_assembly_pk_add_f16
(
bit_cast
<
half2_t
>
(
lo
),
bit_cast
<
half2_t
>
(
SUB
));
}
__host__
__device__
inline
bhalf4_t
pki4_to_bhalf4
(
pk_i4x2_t
i4s
)
__host__
__device__
inline
bhalf4_t
pki4_to_bhalf4
(
int
q
)
{
uint32_t
q
=
bit_cast
<
uint16_t
>
(
i4s
);
uint32_t
i8s
=
(
q
&
0xf
)
|
(
q
&
0xf0
<<
4
)
|
(
q
&
0xf00
<<
8
)
|
(
q
&
0xf000
<<
12
);
uint32_t
i8s
=
(
q
&
0xf
)
|
((
q
&
0xf0
)
<<
4
)
|
((
q
&
0xf00
)
<<
8
)
|
((
q
&
0xf000
)
<<
12
);
//uint32_t i8s = q & 0xf0f0f0f;
static
constexpr
uint32_t
fp32_base
=
0x4B000000
;
float
fp32_intermediates
[
4
];
...
...
@@ -72,8 +73,8 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(pk_i4x2_t i4s)
fp32_intermediates
[
3
]
-=
8388616.
f
;
vector_type
<
bhalf_t
,
4
>
res
;
res
.
template
AsType
<
bhalf2_t
>()(
Number
<
0
>
{})
=
__byte_perm
(
fp32_intermediates_casted
[
0
],
fp32_intermediates_casted
[
1
],
0x7632
);
res
.
template
AsType
<
bhalf2_t
>()(
Number
<
1
>
{})
=
__byte_perm
(
fp32_intermediates_casted
[
1
],
fp32_intermediates_casted
[
2
],
0x7632
);
res
.
template
AsType
<
bhalf2_t
>()(
Number
<
1
>
{})
=
bit_cast
<
bhalf2_t
>
(
__byte_perm
(
fp32_intermediates_casted
[
0
],
fp32_intermediates_casted
[
1
],
0x7632
)
)
;
res
.
template
AsType
<
bhalf2_t
>()(
Number
<
0
>
{})
=
bit_cast
<
bhalf2_t
>
(
__byte_perm
(
fp32_intermediates_casted
[
2
],
fp32_intermediates_casted
[
3
],
0x7632
)
)
;
return
res
.
template
AsType
<
bhalf4_t
>()[
Number
<
0
>
{}];
}
...
...
@@ -133,8 +134,9 @@ struct PassThroughPack8
#if 1
vector_type
<
bhalf_t
,
8
>
result
;
result
.
template
AsType
<
bhalf4_t
>()(
Number
<
0
>
{})
=
pki4_to_bhalf4
(
bit_cast
<
int
>
(
x
));
result
.
template
AsType
<
bhalf4_t
>()(
Number
<
1
>
{})
=
pki4_to_bhalf4
(
bit_cast
<
int
>
(
x
)
>>
16
);
result
.
template
AsType
<
bhalf4_t
>()(
Number
<
0
>
{})
=
pki4_to_bhalf4
(
bit_cast
<
int
>
(
x
)
>>
16
);
result
.
template
AsType
<
bhalf4_t
>()(
Number
<
1
>
{})
=
pki4_to_bhalf4
(
bit_cast
<
int
>
(
x
));
y
=
result
.
template
AsType
<
bhalf8_t
>()[
Number
<
0
>
{}];
#else
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
a6ccd2ec
...
...
@@ -40,11 +40,11 @@ __global__ void
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
splitk_batch_offset
.
a_k_split_offset
,
karg
.
p_b_grid
+
splitk_batch_offset
.
b_k_split_offset
,
karg
.
p_c_grid
+
splitk_batch_offset
.
c_reduce_offset
,
p_shared
,
karg
);
karg
.
p_a_grid
+
splitk_batch_offset
.
a_k_split_offset
,
karg
.
p_b_grid
+
splitk_batch_offset
.
b_k_split_offset
,
karg
.
p_c_grid
+
splitk_batch_offset
.
c_reduce_offset
,
p_shared
,
karg
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx9__))
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp
View file @
a6ccd2ec
...
...
@@ -177,6 +177,11 @@ void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
I4
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
I4
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -827,6 +832,16 @@ struct DeviceOperationInstanceFactory<
}
}
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
pk_i4_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
...
...
library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt
View file @
a6ccd2ec
...
...
@@ -98,6 +98,7 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
...
...
profiler/include/profiler/profile_gemm_universal_impl.hpp
View file @
a6ccd2ec
...
...
@@ -175,51 +175,54 @@ bool profile_gemm_universal_impl(int do_verification,
}
}
// vector pk_i4x4 permute
for
(
int
i
=
0
;
i
<
N
;
i
++
)
if
(
is_same_v
<
BDataType
,
pk_i4_t
>
&&
is_same_v
<
ADataType
,
half_t
>
)
{
for
(
int
j
=
0
;
j
<
K
;
j
+=
8
)
// vector pk_i4x4 permute
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
int
input
[
8
];
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
int
i4x2
=
b_k_n_permute
(
j
+
k
*
2
,
i
);
input
[
k
*
2
+
0
]
=
(
i4x2
>>
4
)
&
0xf
;
input
[
k
*
2
+
1
]
=
(
i4x2
>>
0
)
&
0xf
;
}
// permute 01234567->20643175
{
int
hi
=
input
[
2
];
int
lo
=
input
[
0
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
0
,
i
)
=
i4x2
;
}
for
(
int
j
=
0
;
j
<
K
;
j
+=
8
)
{
int
hi
=
input
[
6
];
int
lo
=
input
[
4
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
2
,
i
)
=
i4x2
;
}
{
int
hi
=
input
[
3
];
int
lo
=
input
[
1
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
4
,
i
)
=
i4x2
;
}
{
int
hi
=
input
[
7
];
int
lo
=
input
[
5
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
6
,
i
)
=
i4x2
;
int
input
[
8
];
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
int
i4x2
=
b_k_n_permute
(
j
+
k
*
2
,
i
);
input
[
k
*
2
+
0
]
=
(
i4x2
>>
4
)
&
0xf
;
input
[
k
*
2
+
1
]
=
(
i4x2
>>
0
)
&
0xf
;
}
// permute 01234567->20643175
{
int
hi
=
input
[
2
];
int
lo
=
input
[
0
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
0
,
i
)
=
i4x2
;
}
{
int
hi
=
input
[
6
];
int
lo
=
input
[
4
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
2
,
i
)
=
i4x2
;
}
{
int
hi
=
input
[
3
];
int
lo
=
input
[
1
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
4
,
i
)
=
i4x2
;
}
{
int
hi
=
input
[
7
];
int
lo
=
input
[
5
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
6
,
i
)
=
i4x2
;
}
}
}
}
...
...
profiler/src/profile_gemm_universal.cpp
View file @
a6ccd2ec
...
...
@@ -28,6 +28,7 @@ enum struct GemmDataType
F16_F16_F16_F8
,
// 6
F8_F8_BF16
,
// 7
F16_I4_F16
,
// 8
BF16_I4_BF16
,
// 9
};
#define OP_NAME "gemm_universal"
...
...
@@ -40,7 +41,7 @@ int profile_gemm_universal(int argc, char* argv[])
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8; 8: f16@i4
)
\n
"
);
"comp f8; 8: f16@i4
; 9: bf16@i4
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
...
...
@@ -193,6 +194,10 @@ int profile_gemm_universal(int argc, char* argv[])
{
return
profile
(
F16
{},
I4
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_I4_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
BF16
{},
I4
{},
BF16
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{});
}
else
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment