Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
05d7a087
"test/srt/git@developer.sourcefind.cn:change/sglang.git" did not exist on "caa4819bfcdc1b0e081d2b93500ea3d4d2cb8e00"
Commit
05d7a087
authored
Apr 03, 2019
by
Chao Liu
Browse files
enable 128x128 block gemm
parent
6a3f3f95
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
112 additions
and
125 deletions
+112
-125
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
+3
-3
src/include/blockwise_gemm.hip.hpp
src/include/blockwise_gemm.hip.hpp
+7
-7
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
...idwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
+3
-2
src/include/inline_asm.hpp
src/include/inline_asm.hpp
+93
-107
src/include/threadwise_gemm.hip.hpp
src/include/threadwise_gemm.hip.hpp
+6
-6
No files found.
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
View file @
05d7a087
...
...
@@ -190,7 +190,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
BlockSize
=
256
;
#elif
1
#elif
0
// 1x1, 14x14, Vega 20, disable lds_double_buffer, enable register double buffer
constexpr
index_t
BPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
128
;
...
...
@@ -221,8 +221,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr
index_t
BlockSize
=
128
;
#elif 1
// 1x1, 14x14, Vega 20,
hack CPerBlock = 1
constexpr
index_t
BPerBlock
=
64
;
// 1x1, 14x14, Vega 20,
try
constexpr
index_t
BPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
CPerBlock
=
8
;
...
...
src/include/blockwise_gemm.hip.hpp
View file @
05d7a087
...
...
@@ -377,13 +377,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr
index_t
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
index_t
NRepeat
=
NPerThread
/
NPerThreadSubC
;
//auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
//auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
//
auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
//
auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
Float4
*
reg_a
=
(
Float4
*
)(
p_a_thread
);
Float4
*
reg_b
=
(
Float4
*
)(
p_b_thread
);
Float4
*
reg_c
=
(
Float4
*
)(
p_c_thread
);
void
*
a_loc
=
(
void
*
)(
p_a_block
+
mMyThreadOffsetA
);
void
*
b_loc
=
(
void
*
)(
p_b_block
+
mMyThreadOffsetB
);
void
*
a_loc
=
(
void
*
)(
p_a_block
+
mMyThreadOffsetA
);
void
*
b_loc
=
(
void
*
)(
p_b_block
+
mMyThreadOffsetB
);
// loop over k
int
k_chunk
=
2
;
#pragma unroll
...
...
@@ -403,9 +403,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
#else
int
k
=
k_begin
;
int
lds_a_block_off
=
sizeof
(
Float
)
*
M
;
int
lds_b_block_off
=
sizeof
(
Float
)
*
N
;
int
k
=
k_begin
;
int
lds_a_block_off
=
sizeof
(
Float
)
*
M
;
int
lds_b_block_off
=
sizeof
(
Float
)
*
N
;
int
lds_a_block_off_1
=
MPerLevel1Cluster
*
sizeof
(
Float
);
int
lds_b_block_off_1
=
NPerLevel1Cluster
*
sizeof
(
Float
);
ds_read_b128
(
reg_a
[
0
],
a_loc
,
k
*
lds_a_block_off
);
...
...
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
View file @
05d7a087
...
...
@@ -272,7 +272,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
// LDS: be careful of alignment
constexpr
index_t
max_align
=
mod_conv
::
max
(
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
mod_conv
::
max
(
index_t
(
4
),
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
constexpr
index_t
in_block_element_space
=
in_cb_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
...
...
@@ -297,7 +297,8 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
for
(
index_t
c_block_data_begin
=
0
;
c_block_data_begin
<
C
;
c_block_data_begin
+=
CPerBlock
,
p_in_global_block_offset
+=
CPerBlock
*
in_cb_global_desc
.
GetStride
(
I0
),
p_wei_global_block_offset
+=
CPerBlock
*
wei_cyxk_global_desc
.
GetStride
(
I0
),
__syncthreads
())
p_wei_global_block_offset
+=
CPerBlock
*
wei_cyxk_global_desc
.
GetStride
(
I0
),
__syncthreads
())
{
// load data
blockwise_in_copy
.
Run
(
p_in_global_block_offset
,
p_in_block
);
...
...
src/include/inline_asm.hpp
View file @
05d7a087
...
...
@@ -4,56 +4,68 @@ typedef float Float4 __attribute__((ext_vector_type(4)));
extern
"C"
__attribute__
((
address_space
(
3
)))
void
*
__to_local
(
void
*
p
)[[
hc
]];
inline
__device__
void
lgkmcnt
(
int
cnt
){
inline
__device__
void
lgkmcnt
(
int
cnt
)
{
#if 1
if
(
cnt
==
0
)
{
if
(
cnt
==
0
)
{
asm
volatile
(
"
\n
\
s_waitcnt lgkmcnt(0)
\n
\
"
::
);
"
::
);
}
else
if
(
cnt
==
1
)
{
else
if
(
cnt
==
1
)
{
asm
volatile
(
"
\n
\
s_waitcnt lgkmcnt(1)
\n
\
"
::
);
"
::
);
}
else
if
(
cnt
==
2
)
{
else
if
(
cnt
==
2
)
{
asm
volatile
(
"
\n
\
s_waitcnt lgkmcnt(2)
\n
\
"
::
);
"
::
);
}
else
if
(
cnt
==
3
)
{
else
if
(
cnt
==
3
)
{
asm
volatile
(
"
\n
\
s_waitcnt lgkmcnt(3)
\n
\
"
::
);
"
::
);
}
else
if
(
cnt
==
4
)
{
else
if
(
cnt
==
4
)
{
asm
volatile
(
"
\n
\
s_waitcnt lgkmcnt(4)
\n
\
"
::
);
"
::
);
}
else
{
else
{
assert
(
0
);
}
#endif
}
inline
__device__
void
outerProduct1x4
(
const
float
*
a
,
const
float
*
b
,
float
*
c
)
{
inline
__device__
void
outerProduct1x4
(
const
float
*
a
,
const
float
*
b
,
float
*
c
)
{
asm
volatile
(
"
\n
\
v_mac_f32 %0, %4, %5
\n
\
v_mac_f32 %1, %4, %6
\n
\
v_mac_f32 %2, %4, %7
\n
\
v_mac_f32 %3, %4, %8
\n
\
"
:
"=v"
(
c
[
0
]),
"=v"
(
c
[
1
]),
"=v"
(
c
[
2
]),
"=v"
(
c
[
3
])
:
"v"
(
a
[
0
]),
"v"
(
b
[
0
]),
"v"
(
b
[
1
]),
"v"
(
b
[
2
]),
"v"
(
b
[
3
]),
"0"
(
c
[
0
]),
"1"
(
c
[
1
]),
"2"
(
c
[
2
]),
"3"
(
c
[
3
])
);
:
"=v"
(
c
[
0
]),
"=v"
(
c
[
1
]),
"=v"
(
c
[
2
]),
"=v"
(
c
[
3
])
:
"v"
(
a
[
0
]),
"v"
(
b
[
0
]),
"v"
(
b
[
1
]),
"v"
(
b
[
2
]),
"v"
(
b
[
3
]),
"0"
(
c
[
0
]),
"1"
(
c
[
1
]),
"2"
(
c
[
2
]),
"3"
(
c
[
3
]));
}
inline
__device__
void
outerProduct1x4
(
const
float
&
a
,
const
Float4
&
b
,
Float4
&
c
)
{
inline
__device__
void
outerProduct1x4
(
const
float
&
a
,
const
Float4
&
b
,
Float4
&
c
)
{
#if 0
asm volatile(
"\n \
...
...
@@ -67,12 +79,13 @@ inline __device__ void outerProduct1x4(const float &a, const Float4 &b, Float4 &
"v"(a.x),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w)
);
#else
outerProduct1x4
(
&
a
,
(
float
*
)
&
b
,
(
float
*
)
&
c
);
outerProduct1x4
(
&
a
,
(
float
*
)
&
b
,
(
float
*
)
&
c
);
#endif
}
inline
__device__
void
outerProduct4x4
(
const
Float4
&
a
,
const
Float4
&
b
,
Float4
&
c0
,
Float4
&
c1
,
Float4
&
c2
,
Float4
&
c3
)
{
inline
__device__
void
outerProduct4x4
(
const
Float4
&
a
,
const
Float4
&
b
,
Float4
&
c0
,
Float4
&
c1
,
Float4
&
c2
,
Float4
&
c3
)
{
#if 0
asm volatile(
"\n \
...
...
@@ -126,7 +139,7 @@ inline __device__ void outerProduct4x4(const Float4 &a, const Float4 &b, Float4
#endif
}
inline
__device__
void
outerProduct8x8
(
const
Float4
*
a
,
const
Float4
*
b
,
Float4
*
c
)
inline
__device__
void
outerProduct8x8
(
const
Float4
*
a
,
const
Float4
*
b
,
Float4
*
c
)
{
outerProduct4x4
(
a
[
0
],
b
[
0
],
c
[
0
],
c
[
2
],
c
[
4
],
c
[
6
]);
outerProduct4x4
(
a
[
0
],
b
[
1
],
c
[
1
],
c
[
3
],
c
[
5
],
c
[
7
]);
...
...
@@ -134,250 +147,223 @@ inline __device__ void outerProduct8x8(const Float4 *a, const Float4 *b, Float4
outerProduct4x4
(
a
[
1
],
b
[
1
],
c
[
9
],
c
[
11
],
c
[
13
],
c
[
15
]);
}
inline
__device__
void
ds_read_b128
(
Float4
&
r
,
void
*
lds
,
int
offset
=
0
)
inline
__device__
void
ds_read_b128
(
Float4
&
r
,
void
*
lds
,
int
offset
=
0
)
{
if
(
offset
==
0
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:0
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
128
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:128
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
256
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:256
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
384
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:384
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
512
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:512
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
640
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:640
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
768
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:768
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
896
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:896
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
1024
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1024
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
1152
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1152
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
1280
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1280
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
1408
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1408
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
1536
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1536
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
1664
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1664
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
1792
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1792
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
1920
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:1920
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
2048
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2048
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
2176
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2176
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
2304
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2304
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
2560
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2560
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
2816
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2816
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
3072
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3072
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
3328
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3328
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
3584
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3584
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
3840
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3840
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
4096
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:4096
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
4352
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:4352
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
))
);
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
{
...
...
src/include/threadwise_gemm.hip.hpp
View file @
05d7a087
...
...
@@ -31,10 +31,10 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
const
index_t
src_index
=
src_mtx
.
Get1dIndex
(
i
,
0
);
const
index_t
dst_index
=
dst_mtx
.
Get1dIndex
(
i
,
0
);
Float4
*
reg_p
=
(
Float4
*
)
&
p_dst
[
dst_index
];
Float4
*
loc_p
=
(
Float4
*
)
&
p_src
[
src_index
];
Float4
*
reg_p
=
(
Float4
*
)
&
p_dst
[
dst_index
];
Float4
*
loc_p
=
(
Float4
*
)
&
p_src
[
src_index
];
ds_read_b128
(
reg_p
[
0
],
(
void
*
)
&
loc_p
[
0
]);
ds_read_b128
(
reg_p
[
0
],
(
void
*
)
&
loc_p
[
0
]);
}
#endif
}
...
...
@@ -86,9 +86,9 @@ __device__ void threadwise_gemm(MatrixA,
}
}
#else
const
Float4
*
a_vec
=
(
const
Float4
*
)
p_a_thread
;
const
Float4
*
b_vec
=
(
const
Float4
*
)
p_b_thread
;
Float4
*
c_vec
=
(
Float4
*
)
p_c_thread
;
const
Float4
*
a_vec
=
(
const
Float4
*
)
p_a_thread
;
const
Float4
*
b_vec
=
(
const
Float4
*
)
p_b_thread
;
Float4
*
c_vec
=
(
Float4
*
)
p_c_thread
;
outerProduct8x8
(
a_vec
,
b_vec
,
c_vec
);
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment