gaoqiong / composable_kernel

Commit 89123dd7
authored Jan 06, 2021 by Chao Liu

refactor

parent 76df7392
Showing 1 changed file with 72 additions and 64 deletions
composable_kernel/include/tensor_operation/gridwise_gemm.hpp  +72 -64
@@ -68,13 +68,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
            Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

        // LDS allocation for A and B: be careful of alignment
-       constexpr index_t a_block_space =
+       constexpr index_t a_block_space_size =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);

-       constexpr index_t b_block_space =
+       constexpr index_t b_block_space_size =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

-       return 2 * (a_block_space + b_block_space) * sizeof(Float);
+       return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
    }

    __device__ void Run(const Float* __restrict__ p_a_global,
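The renamed constants above are element counts for the A and B blocks in LDS, rounded up to the alignment, and the byte count doubles them because the kernel keeps two copies of each block for double buffering. A minimal sketch of that arithmetic follows; the standalone helper only assumes the rounding behavior suggested by the name math::integer_least_multiple, and the concrete sizes are made up for illustration.

#include <cstddef>

// Assumed semantics of math::integer_least_multiple: the smallest multiple of
// `align` that is >= `size`. This standalone version is illustrative only.
constexpr std::size_t integer_least_multiple(std::size_t size, std::size_t align)
{
    return ((size + align - 1) / align) * align;
}

// Hypothetical element counts: an A block of 1000 and a B block of 1500 elements,
// padded to a 64-element alignment, double buffered, stored as float.
static_assert(integer_least_multiple(1000, 64) == 1024, "");
static_assert(integer_least_multiple(1500, 64) == 1536, "");
constexpr std::size_t lds_bytes = 2 * (1024 + 1536) * sizeof(float); // 20480 bytes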
@@ -209,14 +209,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
            ThreadGemmBThreadCopySrcDataPerRead_N>{};

        // LDS allocation for A and B: be careful of alignment
-       constexpr index_t a_block_space =
+       constexpr index_t a_block_space_size =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);

-       constexpr index_t b_block_space =
+       constexpr index_t b_block_space_size =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

        Float* p_a_block_double = p_shared_block;
-       Float* p_b_block_double = p_shared_block + 2 * a_block_space;
+       Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;

        // register allocation for output
        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
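For orientation, the pointer arithmetic above carves one shared allocation into an A region followed by a B region, and each region is split into an even and an odd buffer further down. A small sketch of the resulting element offsets, with made-up sizes rather than values from this file:

// Element offsets inside p_shared_block implied by the pointer arithmetic above;
// a_block_space_size and b_block_space_size are hypothetical here.
constexpr int a_block_space_size = 1024;
constexpr int b_block_space_size = 1536;

constexpr int a_even_offset = 0;                                           // p_a_block_double
constexpr int a_odd_offset  = a_block_space_size;                          // p_a_block_double + a_block_space_size
constexpr int b_even_offset = 2 * a_block_space_size;                      // p_b_block_double
constexpr int b_odd_offset  = 2 * a_block_space_size + b_block_space_size; // p_b_block_double + b_block_space_size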
@@ -230,47 +230,55 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }

-       constexpr auto a_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
-       constexpr auto b_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
+       constexpr auto a_block_slice_copy_step = Sequence<KPerBlock, 0>{};
+       constexpr auto b_block_slice_copy_step = Sequence<KPerBlock, 0>{};
+
+       Float* p_a_block_even = p_a_block_double;
+       Float* p_b_block_even = p_b_block_double;
+
+       Float* p_a_block_odd = p_a_block_double + a_block_space_size;
+       Float* p_b_block_odd = p_b_block_double + b_block_space_size;

        // LDS double buffer: main body
-       for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
+       for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
            k_block_data_begin += 2 * KPerBlock)
        {
-#pragma unroll
-           for(index_t iloop = 0; iloop < 2; ++iloop)
-           {
-               const bool even_loop = (iloop % 2 == 0);
-
-               Float* p_a_block_now =
-                   even_loop ? p_a_block_double : p_a_block_double + a_block_space;
-               Float* p_b_block_now =
-                   even_loop ? p_b_block_double : p_b_block_double + b_block_space;
-
-               Float* p_a_block_next =
-                   even_loop ? p_a_block_double + a_block_space : p_a_block_double;
-               Float* p_b_block_next =
-                   even_loop ? p_b_block_double + b_block_space : p_b_block_double;
-
-               Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
-               Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
-
-               a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
-               b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
-
-               __syncthreads();
-
-               // LDS doubel buffer: load next data from device mem
-               a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
-               b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
-
-               // LDS double buffer: GEMM on current data
-               blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
-
-               // LDS double buffer: store next data to LDS
-               a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
-               b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
-           }
+           Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
+           Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
+
+           // even iteration
+           a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
+           b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
+
+           __syncthreads();
+
+           // LDS doubel buffer: load next data from device mem
+           a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
+           b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+
+           // LDS double buffer: GEMM on current data
+           blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
+
+           // LDS double buffer: store next data to LDS
+           a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_odd);
+           b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_odd);
+
+           // odd iteration
+           a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
+           b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
+
+           __syncthreads();
+
+           // LDS doubel buffer: load next data from device mem
+           a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
+           b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+
+           // LDS double buffer: GEMM on current data
+           blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
+
+           // LDS double buffer: store next data to LDS
+           a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_even);
+           b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_even);
        }

        // LDS double buffer: tail
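The structural change in this hunk is that the old inner `iloop` loop, which selected current/next buffers with ternaries on every pass, is unrolled by hand into an even half and an odd half that work on fixed even/odd pointers. Each half follows the same pipeline: advance the global copy window, synchronize, load the next K tile into registers, run the GEMM on the tile already resident in LDS, then store the registers into the other LDS buffer. A compact, self-contained sketch of that schedule follows; the copy and GEMM interfaces here are placeholders, not the library's API.

// Minimal sketch of the double-buffered main body after the refactor: two fixed
// LDS buffers ("even" and "odd") and a manually unrolled loop body.
// Copy/Gemm are placeholder types assumed to provide the members used below.
template <typename Copy, typename Gemm, typename Float, typename Acc>
__device__ void double_buffered_main_body(Copy& a_copy, Copy& b_copy, Gemm& gemm,
                                          const Float* p_a_global, const Float* p_b_global,
                                          Float* a_even, Float* a_odd,
                                          Float* b_even, Float* b_odd,
                                          Acc* p_c_thread, int K, int KPerBlock)
{
    for(int k = 0; k < K - 2 * KPerBlock; k += 2 * KPerBlock)
    {
        Float a_regs[Copy::ThreadBufferSize];
        Float b_regs[Copy::ThreadBufferSize];

        // even half: prefetch the next tile while computing on the even buffers
        a_copy.MoveSrcSliceWindow(KPerBlock);
        b_copy.MoveSrcSliceWindow(KPerBlock);
        __syncthreads();
        a_copy.Load(p_a_global, a_regs);           // global -> registers
        b_copy.Load(p_b_global, b_regs);
        gemm.Run(a_even, b_even, p_c_thread);      // GEMM on data already in LDS
        a_copy.Store(a_regs, a_odd);               // registers -> other LDS buffer
        b_copy.Store(b_regs, b_odd);

        // odd half: same pattern with the roles of the buffers swapped
        a_copy.MoveSrcSliceWindow(KPerBlock);
        b_copy.MoveSrcSliceWindow(KPerBlock);
        __syncthreads();
        a_copy.Load(p_a_global, a_regs);
        b_copy.Load(p_b_global, b_regs);
        gemm.Run(a_odd, b_odd, p_c_thread);
        a_copy.Store(a_regs, a_even);
        b_copy.Store(b_regs, b_even);
    }
}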
@@ -282,8 +290,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
            Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
            Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

-           a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
-           b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+           a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
+           b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);

            __syncthreads();
@@ -296,15 +304,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
-                                                 p_a_block_double + a_block_space);
+                                                 p_a_block_double + a_block_space_size);
            b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
-                                                 p_b_block_double + b_block_space);
+                                                 p_b_block_double + b_block_space_size);

            __syncthreads();

            // LDS double buffer: GEMM on last data
-           blockwise_gemm.Run(p_a_block_double + a_block_space,
-                              p_b_block_double + b_block_space,
-                              p_c_thread);
+           blockwise_gemm.Run(p_a_block_double + a_block_space_size,
+                              p_b_block_double + b_block_space_size,
+                              p_c_thread);
        }
        else // if has 1 iteration left
        {
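As a concrete check of the main-body bound (with made-up sizes, not values from this kernel): the main loop stops two K tiles early, and this tail hunk then consumes the remaining one or two tiles.

// Count how many double-buffered passes the main body makes for hypothetical
// sizes; the bound mirrors `k_block_data_begin < K - 2 * KPerBlock` above.
constexpr int main_body_passes(int K, int KPerBlock)
{
    int n = 0;
    for(int k = 0; k < K - 2 * KPerBlock; k += 2 * KPerBlock)
    {
        ++n;
    }
    return n;
}

static_assert(main_body_passes(256, 32) == 3, "covers 6 of the 8 K tiles");
static_assert(256 / 32 - 2 * main_body_passes(256, 32) == 2, "2 tiles left for the tail");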
@@ -433,13 +442,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
            Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

        // LDS allocation for A and B: be careful of alignment
-       constexpr index_t a_block_space =
+       constexpr index_t a_block_space_size =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);

-       constexpr index_t b_block_space =
+       constexpr index_t b_block_space_size =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

-       return 2 * (a_block_space + b_block_space) * sizeof(Float);
+       return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
    }

    __device__ void Run(const Float* __restrict__ p_a_global,
@@ -584,14 +593,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
            ThreadGemmBThreadCopySrcDataPerRead_N>{};

        // LDS allocation for A and B: be careful of alignment
-       constexpr index_t a_block_space =
+       constexpr index_t a_block_space_size =
            math::integer_least_multiple(a_k0_k1_k2_m_block_desc.GetElementSpace(), max_lds_align);

-       constexpr index_t b_block_space =
+       constexpr index_t b_block_space_size =
            math::integer_least_multiple(b_k0_k1_k2_n_block_desc.GetElementSpace(), max_lds_align);

        Float* p_a_block_double = p_shared_block;
-       Float* p_b_block_double = p_shared_block + 2 * a_block_space;
+       Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;

        // register allocation for output
        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
@@ -603,15 +612,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
        {
            for(index_t k1 = 0; k1 < K1; ++k1)
            {
                // LDS double buffer: preload data into LDS
                {
                    a_blockwise_copy.Run(p_a_global, p_a_block_double);
                    b_blockwise_copy.Run(p_b_global, p_b_block_double);
                }

-               constexpr auto a_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{};
-               constexpr auto b_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{};
+               constexpr auto a_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};
+               constexpr auto b_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};

                // LDS double buffer: main body
                for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
@@ -623,20 +631,20 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
                    const bool even_loop = (iloop % 2 == 0);

                    Float* p_a_block_now =
-                       even_loop ? p_a_block_double : p_a_block_double + a_block_space;
+                       even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
                    Float* p_b_block_now =
-                       even_loop ? p_b_block_double : p_b_block_double + b_block_space;
+                       even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;

                    Float* p_a_block_next =
-                       even_loop ? p_a_block_double + a_block_space : p_a_block_double;
+                       even_loop ? p_a_block_double + a_block_space_size : p_a_block_double;
                    Float* p_b_block_next =
-                       even_loop ? p_b_block_double + b_block_space : p_b_block_double;
+                       even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;

                    Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                    Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

-                   a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
-                   b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+                   a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
+                   b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);

                    __syncthreads();
@@ -662,8 +670,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
                Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

-               a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
-               b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+               a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
+               b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);

                __syncthreads();
@@ -675,16 +683,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
                blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

                // LDS double buffer: store last data to LDS
-               a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
-                                                     p_a_block_double + a_block_space);
-               b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
-                                                     p_b_block_double + b_block_space);
+               a_blockwise_copy.RunStoreThreadBuffer(
+                   p_a_thread_buffer, p_a_block_double + a_block_space_size);
+               b_blockwise_copy.RunStoreThreadBuffer(
+                   p_b_thread_buffer, p_b_block_double + b_block_space_size);

                __syncthreads();

                // LDS double buffer: GEMM on last data
-               blockwise_gemm.Run(p_a_block_double + a_block_space,
-                                  p_b_block_double + b_block_space,
-                                  p_c_thread);
+               blockwise_gemm.Run(p_a_block_double + a_block_space_size,
+                                  p_b_block_double + b_block_space_size,
+                                  p_c_thread);
            }
            else // if has 1 iteration left