Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d6a7acf6
Commit
d6a7acf6
authored
Oct 29, 2024
by
Jing Zhang
Browse files
fixed int4 to bf16 conversion
parent
9de3a085
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
6 deletions
+24
-6
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
+1
-1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+5
-5
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+18
-0
No files found.
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
View file @
d6a7acf6
...
@@ -65,7 +65,7 @@ using DeviceGemmV2Instance =
...
@@ -65,7 +65,7 @@ using DeviceGemmV2Instance =
2
,
32
,
32
,
0
,
2
,
32
,
32
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
#endif
#endif
ck
::
BlockGemmPipelineScheduler
::
Interwave
,
ck
::
BlockGemmPipelineVersion
::
v2
,
C
DataType
,
C
DataType
,
false
,
PermuteB
>
;
ck
::
BlockGemmPipelineScheduler
::
Interwave
,
ck
::
BlockGemmPipelineVersion
::
v2
,
A
DataType
,
A
DataType
,
false
,
PermuteB
>
;
// clang-format on
// clang-format on
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
d6a7acf6
...
@@ -73,10 +73,10 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
...
@@ -73,10 +73,10 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
fp32_intermediates
[
3
]
-=
8388616.
f
;
fp32_intermediates
[
3
]
-=
8388616.
f
;
vector_type
<
bhalf_t
,
4
>
res
;
vector_type
<
bhalf_t
,
4
>
res
;
res
.
template
AsType
<
bhalf2_t
>()(
Number
<
1
>
{})
=
bit_cast
<
bhalf2_t
>
(
__byte_perm
(
fp32_intermediates_casted
[
0
],
fp32_intermediates_casted
[
1
],
0x7632
));
res
.
template
AsType
<
bhalf2_t
>()(
Number
<
0
>
{})
=
bit_cast
<
bhalf2_t
>
(
res
.
template
AsType
<
bhalf2_t
>()(
Number
<
0
>
{})
=
bit_cast
<
bhalf2_t
>
(
__byte_perm
(
fp32_intermediates_casted
[
2
],
fp32_intermediates_casted
[
3
],
0x7632
));
__byte_perm
(
fp32_intermediates_casted
[
1
],
fp32_intermediates_casted
[
0
],
0x7632
));
res
.
template
AsType
<
bhalf2_t
>()(
Number
<
1
>
{})
=
bit_cast
<
bhalf2_t
>
(
__byte_perm
(
fp32_intermediates_casted
[
3
],
fp32_intermediates_casted
[
2
],
0x7632
));
return
res
.
template
AsType
<
bhalf4_t
>()[
Number
<
0
>
{}];
return
res
.
template
AsType
<
bhalf4_t
>()[
Number
<
0
>
{}];
}
}
...
@@ -135,8 +135,8 @@ struct PassThroughPack8
...
@@ -135,8 +135,8 @@ struct PassThroughPack8
#if 1
#if 1
vector_type
<
bhalf_t
,
8
>
result
;
vector_type
<
bhalf_t
,
8
>
result
;
result
.
template
AsType
<
bhalf4_t
>()(
Number
<
0
>
{})
=
pki4_to_bhalf4
(
bit_cast
<
int
>
(
x
)
>>
16
);
result
.
template
AsType
<
bhalf4_t
>()(
Number
<
0
>
{})
=
pki4_to_bhalf4
(
bit_cast
<
int
>
(
x
));
result
.
template
AsType
<
bhalf4_t
>()(
Number
<
1
>
{})
=
pki4_to_bhalf4
(
bit_cast
<
int
>
(
x
));
result
.
template
AsType
<
bhalf4_t
>()(
Number
<
1
>
{})
=
pki4_to_bhalf4
(
bit_cast
<
int
>
(
x
)
>>
16
);
y
=
result
.
template
AsType
<
bhalf8_t
>()[
Number
<
0
>
{}];
y
=
result
.
template
AsType
<
bhalf8_t
>()[
Number
<
0
>
{}];
#else
#else
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
d6a7acf6
...
@@ -45,6 +45,24 @@ __global__ void
...
@@ -45,6 +45,24 @@ __global__ void
karg
.
p_c_grid
+
splitk_batch_offset
.
c_reduce_offset
,
karg
.
p_c_grid
+
splitk_batch_offset
.
c_reduce_offset
,
p_shared
,
p_shared
,
karg
);
karg
);
// int q = 0x01234567;
// ck::vector_type<ck::bhalf_t, 8> res;
// res.template AsType<ck::bhalf4_t>()(ck::Number<0>{}) = ck::pki4_to_bhalf4(q >> 16);
// res.template AsType<ck::bhalf4_t>()(ck::Number<1>{}) = ck::pki4_to_bhalf4(q);
// if(threadIdx.x == 0 && blockIdx.x == 0)
// printf("%f %f %f %f %f %f %f %f\n",
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<0>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<1>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<2>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<3>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<4>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<5>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<6>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<7>{}])
//);
#else
#else
ignore
=
karg
;
ignore
=
karg
;
#endif // end of if (defined(__gfx9__))
#endif // end of if (defined(__gfx9__))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment