Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Lmdeploy
Commits
f8ed456e
Unverified
Commit
f8ed456e
authored
Aug 17, 2023
by
Li Zhang
Committed by
GitHub
Aug 17, 2023
Browse files
[Fix] Implement movmatrix using warp shuffling for CUDA < 11.8 (#267)
parent
903707b5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
2 deletions
+33
-2
src/turbomind/kernels/gemm_s_f16/gemm_template.h
src/turbomind/kernels/gemm_s_f16/gemm_template.h
+33
-2
No files found.
src/turbomind/kernels/gemm_s_f16/gemm_template.h
View file @
f8ed456e
...
@@ -26,7 +26,26 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<ha
...
@@ -26,7 +26,26 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<ha
#endif
#endif
}
}
__inline__
__device__
uint
transpose_m8n8_b16
(
uint
a
)
__inline__
__device__
uint
transpose_m8n8_b16_warp_shuffle
(
uint
value
,
int
lane_id
)
{
int
src_lane
=
lane_id
/
8
+
lane_id
%
4
*
8
;
uint
u0
=
__shfl_sync
(
0xffffffff
,
value
,
src_lane
);
uint
u1
=
__shfl_sync
(
0xffffffff
,
value
,
src_lane
+
4
);
short2
r
;
if
(
lane_id
%
8
<
4
)
{
r
.
x
=
((
short2
&
)
u0
).
x
;
r
.
y
=
((
short2
&
)
u1
).
x
;
}
else
{
r
.
x
=
((
short2
&
)
u0
).
y
;
r
.
y
=
((
short2
&
)
u1
).
y
;
}
return
(
uint
&
)
r
;
}
#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
__inline__
__device__
uint
transpose_m8n8_b16_movmatrix
(
uint
a
)
{
{
#if TURBOMIND_ARCH_SM75
#if TURBOMIND_ARCH_SM75
uint
d
;
uint
d
;
...
@@ -37,6 +56,18 @@ __inline__ __device__ uint transpose_m8n8_b16(uint a)
...
@@ -37,6 +56,18 @@ __inline__ __device__ uint transpose_m8n8_b16(uint a)
return
0
;
return
0
;
#endif
#endif
}
}
#endif
__inline__
__device__
uint
transpose_m8n8_b16
(
uint
a
,
int
lane_id
)
{
#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
(
void
)
lane_id
;
return
transpose_m8n8_b16_movmatrix
(
a
);
#else
return
transpose_m8n8_b16_warp_shuffle
(
a
,
lane_id
);
#endif
}
namespace
ops
{
namespace
ops
{
...
@@ -246,7 +277,7 @@ struct Gemm {
...
@@ -246,7 +277,7 @@ struct Gemm {
// convert to half
// convert to half
half2
half_C
=
__float22half2_rn
(
frag_C
[
j
*
2
+
x
]);
half2
half_C
=
__float22half2_rn
(
frag_C
[
j
*
2
+
x
]);
// transpose 8x8 accum tile
// transpose 8x8 accum tile
uint
trans_C
=
transpose_m8n8_b16
((
uint
&
)
half_C
);
uint
trans_C
=
transpose_m8n8_b16
((
uint
&
)
half_C
,
lane_id
);
// store to global memory
// store to global memory
OutputOps
::
template
apply
<
Index
>(
trans_C
,
mm
,
nn
,
C
,
m
,
n
);
OutputOps
::
template
apply
<
Index
>(
trans_C
,
mm
,
nn
,
C
,
m
,
n
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment