yangql / composable_kernel-1 · Commits · 605afd0f

Commit 605afd0f, authored Apr 04, 2019 by Chao Liu

    Merge branch 'master' into inline_asm_v2

Parents: 66edb259, 6166233e

Showing 14 changed files with 77 additions and 962 deletions (+77 −962)
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp                                  +10  −9
driver/driver.hip.cpp                                                                          +1  −1
script/cmake-cuda.sh                                                                           +5  −4
script/cmake-hip.sh                                                                            +0  −0
script/compile-hip.sh                                                                          +6  −0
script/extract_asm-cuda.sh                                                                     +1  −0
script/tracer-hip.sh                                                                           +3  −0
src/include/blockwise_batched_gemm.hip.hpp                                                     +0  −455
src/include/blockwise_gemm.hip.hpp                                                             +13 −367
src/include/common.hip.hpp                                                                     +7  −2
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp                       +8  −81
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp     +11 −14
src/include/gridwise_convolution_wrapper.hip.hpp                                               +9  −0
src/include/threadwise_gemm.hip.hpp                                                            +3  −29
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp  (+10 −9)

 #pragma once
 #include <unistd.h>
 #include "device.hpp"
-#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp"
-#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
+#include "gridwise_convolution_wrapper.hip.hpp"
+#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp"
+#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"

 template <class T, class InDesc, class WeiDesc, class OutDesc>
 void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 ...
@@ -221,7 +222,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr index_t BlockSize = 128;
 #elif 1
-    // 1x1, 14x14, Vega 20, try
+    // 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer
     constexpr index_t BPerBlock = 128;
     constexpr index_t KPerBlock = 128;
     constexpr index_t CPerBlock = 8;
 ...
@@ -271,10 +272,10 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
         constexpr auto gridwise_conv =
-#if 0
-            gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
+#if 1
+            GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
 #else
-            gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
+            GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
 #endif
             <GridSize,
              BlockSize,
 ...
@@ -301,12 +302,12 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
              WeiBlockCopyThreadPerDim0,
              WeiBlockCopyThreadPerDim1,
              InBlockCopyDataPerRead,
-             WeiBlockCopyDataPerRead>();
+             WeiBlockCopyDataPerRead>{};

-        float time = launch_kernel(gridwise_conv.Run,
+        float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
                                    dim3(GridSize),
                                    dim3(BlockSize),
-                                   gridwise_conv.GetDynamicSharedMemoryUsage(),
+                                   0,
                                    static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
                                    static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
                                    static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
 ...
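Note on the launch-site change above: instead of launching the gridwise object's member Run directly, the driver now launches a free __global__ wrapper (introduced later in this commit as gridwise_convolution_wrapper.hip.hpp) that default-constructs the gridwise functor on the device and calls its __device__ Run. A minimal standalone sketch of that pattern, using a plain triple-chevron launch rather than the repository's launch_kernel helper; the names GridwiseFoo, run_gridwise_op and launch_example are illustrative only:

    // Sketch only: a default-constructible "gridwise" functor whose __device__ Run
    // does the real work, launched through a templated __global__ wrapper.
    template <class GridwiseOp, class T>
    __global__ void run_gridwise_op(const T* in, const T* wei, T* out)
    {
        GridwiseOp{}.Run(in, wei, out); // construct on the device, then run
    }

    struct GridwiseFoo // stand-in for GridwiseConvolutionImplicitGemm_v2_...
    {
        __device__ void Run(const float* in, const float* wei, float* out) const
        {
            // kernel body would go here
        }
    };

    void launch_example(const float* d_in, const float* d_wei, float* d_out,
                        int grid_size, int block_size)
    {
        constexpr auto gridwise_op = GridwiseFoo{};
        // dynamic shared memory is now 0; the kernel declares its own __shared__ arrays
        run_gridwise_op<decltype(gridwise_op), float>
            <<<dim3(grid_size), dim3(block_size), 0>>>(d_in, d_wei, d_out);
    }

This also explains why the dynamic-shared-memory argument drops from gridwise_conv.GetDynamicSharedMemoryUsage() to 0 in the diff: the shared tiles are now statically sized inside Run.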
driver/driver.hip.cpp  (+1 −1)

 ...
@@ -592,7 +592,7 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#elif 1
+#elif 0
     // 1x1 filter, 14x14 image, C = 512
     constexpr index_t N = 128;
     constexpr index_t C = 512;
 ...
build/cmake-cuda.sh → script/cmake-cuda.sh  (+5 −4)

 ...
@@ -15,11 +15,12 @@ cmake
  -D DEVICE_BACKEND=CUDA \
  -D BOOST_ROOT="/package/install/boost_1.67.0" \
  -D CUDA_COMMON_INCLUDE_DIR="/home/chao/code/test_feature/cuda_common/cuda_10.0_common/inc" \
- -D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -arch=sm_61" \
+ -D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -Xptxas -v -maxrregcount=128" \
  ${MY_PROJECT_SOURCE}

 #-D CMAKE_CUDA_COMPILER="/package/install/cuda_10.0/bin/nvcc" \
-#-D CMAKE_CUDA_FLAGS="-G -lineinfo --source-in-ptx -keep -Xptxas -v -arch=sm_61" \
+#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61" \
-#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -arch=sm_61" \
+#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -Xptxas -v -maxrregcount=128" \
-#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -arch=sm_61 -Xptxas -v -maxrregcount=128" \
+#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -gencode=arch=compute_70,code=sm_70" \
+#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -gencode=arch=compute_70,code=sm_70 -Xptxas -v -maxrregcount=128" \
build/cmake-hip.sh → script/cmake-hip.sh  (+0 −0)

 File moved
script/compile-hip.sh  (new file, mode 100755, +6 −0)

+#!/bin/bash
+
+export KMDUMPISA=1
+export KMDUMPLLVM=1
+make -j driver
+/opt/rocm/hcc/bin/llvm-objdump -mcpu=gfx906 -source -line-numbers driver/dump-gfx906.isabin > driver/dump-gfx906.isabin.isa
script/extract_asm-cuda.sh  (new file, mode 100755, +1 −0)

+cuobjdump -xelf all ./driver/driver && nvdisasm --print-code -g driver.sm_61.cubin > driver.sm_61.asm && nvdisasm --print-code -g driver.sm_70.cubin > driver.sm_70.asm
script/tracer-hip.sh  (new file, mode 100755, +3 −0)

+#!/bin/bash
+
+/root/workspace/rocprofiler_pkg/bin/rpl_run.sh --timestamp on -i /root/workspace/rocprofiler_pkg/input.xml -d ./trace ./driver/driver 0 10
src/include/blockwise_batched_gemm.hip.hpp  (+0 −455)

 #pragma once
 #include "threadwise_gemm.hip.hpp"

-template <index_t BlockSize,
-          class BlockMatrixA,
-          class BlockMatrixB,
-          class ThreadMatrixC,
-          bool TransA,
-          bool TransB,
-          bool TransC,
-          index_t BlockMatrixStrideA,
-          index_t BlockMatrixStrideB,
-          index_t ThreadMatrixStrideC,
-          index_t BatchSize,
-          index_t BatchPerThread,
-          index_t KPerThreadLoop,
-          bool DistributeThreadAlongColumnFirst>
-struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
-{
-    index_t mMyThreadOffsetA = 0;
-    index_t mMyThreadOffsetB = 0;
-
-    struct MatrixIndex
-    {
-        index_t batch;
-        index_t row;
-        index_t col;
-    };
-
-    __device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
-    {
-        constexpr auto a_block_mtx = BlockMatrixA{};
-        constexpr auto b_block_mtx = BlockMatrixB{};
-
-        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-
-        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
-                           ((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
-                                      : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row));
-
-        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
-                           ((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
-                                      : b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0));
-
-#if 0
-        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-        {
-            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
-            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
-            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
-
-            printf("%u %u, %u %u %u, %u %u\n",
-                   get_block_1d_id(),
-                   get_thread_local_1d_id(),
-                   c_thread_mtx_index.batch,
-                   c_thread_mtx_index.row,
-                   c_thread_mtx_index.col,
-                   mMyThreadOffsetA,
-                   mMyThreadOffsetB);
-        }
-#endif
-    }
-
-    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
-    {
-        if(TransA && (!TransB) && (!TransC))
-        {
-            constexpr auto a_block_mtx = BlockMatrixA{};
-            constexpr auto b_block_mtx = BlockMatrixB{};
-
-            static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
-                          "wrong! k dimension not consistent!");
-
-            constexpr index_t MPerBlock = a_block_mtx.NCol();
-            constexpr index_t NPerBlock = b_block_mtx.NCol();
-
-            constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-            // divide thread work
-            constexpr index_t MPerThread = c_thread_mtx.NRow();
-            constexpr index_t NPerThread = c_thread_mtx.NCol();
-
-            static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0");
-            static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0");
-            static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0");
-
-            constexpr index_t BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
-            constexpr index_t MThreadWork     = (MPerBlock + MPerThread - 1) / MPerThread;
-            constexpr index_t NThreadWork     = (NPerBlock + NPerThread - 1) / NPerThread;
-
-            static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork,
-                          "wrong! wrong BlockSize");
-
-            if(DistributeThreadAlongColumnFirst)
-            {
-                // num of operations can be reduced
-                const index_t b_work_id = thread_id / (MThreadWork * NThreadWork);
-                index_t itmp            = thread_id - b_work_id * (MThreadWork * NThreadWork);
-
-                const index_t m_work_id = itmp / NThreadWork;
-                const index_t n_work_id = itmp - m_work_id * NThreadWork;
-
-                return MatrixIndex{
-                    b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
-            }
-            else
-            {
-                // not implemented
-                assert(false);
-            }
-        }
-        else
-        {
-            // not implemented
-            assert(false);
-        }
-    }
-
-    // this should be optimized away if input is known
-    __device__ static MatrixIndex
-    GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
-    {
-        return MatrixIndex{batch_in_c, m_in_c, n_in_c};
-    }
-
-    template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void Run(const FloatA* __restrict__ p_a_block,
-                        const FloatB* __restrict__ p_b_block,
-                        FloatC* __restrict__ p_c_thread,
-                        Accumulator f_accum) const
-    {
-        if(TransA && (!TransB) && (!TransC))
-        {
-            constexpr auto True  = integral_constant<bool, true>{};
-            constexpr auto False = integral_constant<bool, false>{};
-
-            constexpr auto a_block_mtx  = BlockMatrixA{};
-            constexpr auto b_block_mtx  = BlockMatrixB{};
-            constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-            constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
-
-            constexpr index_t MPerThread = c_thread_mtx.NRow();
-            constexpr index_t NPerThread = c_thread_mtx.NCol();
-
-            // a is transposed, b is not
-            constexpr auto a_thread_mtx =
-                make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
-            constexpr auto b_thread_mtx =
-                make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
-
-            FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-            FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
-
-            // loop over k
-            for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
-            {
-                // read first batch of a, b
-                threadwise_matrix_copy(a_block_mtx,
-                                       p_a_block + mMyThreadOffsetA + k_begin * a_block_mtx.RowStride(),
-                                       a_thread_mtx,
-                                       p_a_thread,
-                                       a_thread_mtx.GetLengths());
-
-                threadwise_matrix_copy(b_block_mtx,
-                                       p_b_block + mMyThreadOffsetB + k_begin * b_block_mtx.RowStride(),
-                                       b_thread_mtx,
-                                       p_b_thread,
-                                       b_thread_mtx.GetLengths());
-
-                // loop over batch
-                for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
-                {
-                    // do current batch of gemm
-                    threadwise_gemm(a_thread_mtx,
-                                    True,
-                                    p_a_thread,
-                                    b_thread_mtx,
-                                    False,
-                                    p_b_thread,
-                                    c_thread_mtx,
-                                    False,
-                                    p_c_thread + ib * ThreadMatrixStrideC,
-                                    f_accum);
-
-                    // read next batch of a, b
-                    if(BlockMatrixStrideA != 0)
-                    {
-                        threadwise_matrix_copy(a_block_mtx,
-                                               p_a_block + mMyThreadOffsetA +
-                                                   (ib + 1) * BlockMatrixStrideA +
-                                                   k_begin * a_block_mtx.RowStride(),
-                                               a_thread_mtx,
-                                               p_a_thread,
-                                               a_thread_mtx.GetLengths());
-                    }
-
-                    if(BlockMatrixStrideB != 0)
-                    {
-                        threadwise_matrix_copy(b_block_mtx,
-                                               p_b_block + mMyThreadOffsetB +
-                                                   (ib + 1) * BlockMatrixStrideB +
-                                                   k_begin * b_block_mtx.RowStride(),
-                                               b_thread_mtx,
-                                               p_b_thread,
-                                               b_thread_mtx.GetLengths());
-                    }
-                }
-
-                // do last batch of gemm
-                threadwise_gemm(a_thread_mtx,
-                                True,
-                                p_a_thread,
-                                b_thread_mtx,
-                                False,
-                                p_b_thread,
-                                c_thread_mtx,
-                                False,
-                                p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC,
-                                f_accum);
-            }
-        }
-    }
-};
-
 template <index_t BlockSize,
           class BlockMatrixA,
           class BlockMatrixB,
 ...
@@ -526,236 +301,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         }
     }

-    template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void Run_v3(const FloatA* __restrict__ p_a_block,
-                           const FloatB* __restrict__ p_b_block,
-                           FloatC* __restrict__ p_c_thread,
-                           Accumulator f_accum) const
-    {
-        constexpr auto True  = integral_constant<bool, true>{};
-        constexpr auto False = integral_constant<bool, false>{};
-
-        constexpr auto a_block_mtx  = BlockMatrixA{};
-        constexpr auto b_block_mtx  = BlockMatrixB{};
-        constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-        constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
-
-        constexpr index_t MPerThread = c_thread_mtx.NRow();
-        constexpr index_t NPerThread = c_thread_mtx.NCol();
-
-        // thread A, B for GEMM
-        // A is transposed, b is not
-        constexpr auto a_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
-        constexpr auto b_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
-
-        // thread A-sub, B-sub for copy
-        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
-        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
-
-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
-
-        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
-
-        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
-        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
-
-        // loop over k
-        //#pragma unroll
-        for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
-        {
-            // read first batch of A, B
-            // copy A-sub to form A
-            //#pragma unroll
-            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
-            {
-                for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
-                {
-#if 1
-                    for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
-                    {
-                        p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
-                            p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
-                                                             m_repeat * MPerLevel1Cluster + j) +
-                                      mMyThreadOffsetA];
-                    }
-#else
-                    static_assert(a_thread_sub_mtx.NCol() == 4, "asm only read 4xfp32");
-#endif
-                }
-            }
-
-            // copy B-sub to form B
-            //#pragma unroll
-            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
-            {
-                for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
-                {
-                    for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
-                    {
-                        p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
-                            p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
-                                                             n_repeat * MPerLevel1Cluster + j) +
-                                      mMyThreadOffsetB];
-                    }
-                }
-            }
-
-            // loop over batch
-            //#pragma unroll
-            for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
-            {
-                // do current batch of gemm
-                for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
-                {
-#if 0
-                    for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
-                    {
-                        for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
-                        {
-                            const index_t aindex =
-                                a_thread_mtx.Get1dIndex(k, i); // A is transposed
-                            const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
-                            const index_t cindex =
-                                c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;
-
-                            f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
-                        }
-                    }
-#elif 1
-                    static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
-                                  "asm is only for 16x4");
-
-                    const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
-
-                    for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
-                    {
-                        const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
-                        const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
-
-                        asm volatile("\n \
-                        v_mac_f32 %0, %4, %5 \n \
-                        v_mac_f32 %1, %4, %6 \n \
-                        v_mac_f32 %2, %4, %7 \n \
-                        v_mac_f32 %3, %4, %8 \n \
-                        "
-                                     : "=v"(p_c_thread[cindex + 0]),
-                                       "=v"(p_c_thread[cindex + 1]),
-                                       "=v"(p_c_thread[cindex + 2]),
-                                       "=v"(p_c_thread[cindex + 3])
-                                     : "v"(p_a_thread[aindex]),
-                                       "v"(p_b_thread[bindex + 0]),
-                                       "v"(p_b_thread[bindex + 1]),
-                                       "v"(p_b_thread[bindex + 2]),
-                                       "v"(p_b_thread[bindex + 3]),
-                                       "0"(p_c_thread[cindex + 0]),
-                                       "1"(p_c_thread[cindex + 1]),
-                                       "2"(p_c_thread[cindex + 2]),
-                                       "3"(p_c_thread[cindex + 3]));
-                    }
-#endif
-                }
-
-                // read next batch of a, b
-                if(BlockMatrixStrideA != 0)
-                {
-                    //#pragma unroll
-                    for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
-                    {
-                        for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
-                        {
-                            for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
-                            {
-                                p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
-                                    p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
-                                                                     m_repeat * MPerLevel1Cluster + j) +
-                                              (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA];
-                            }
-                        }
-                    }
-                }
-
-                if(BlockMatrixStrideB != 0)
-                {
-                    //#pragma unroll
-                    for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
-                    {
-                        for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
-                        {
-                            for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
-                            {
-                                p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
-                                    p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
-                                                                     n_repeat * MPerLevel1Cluster + j) +
-                                              (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB];
-                            }
-                        }
-                    }
-                }
-            }
-
-            // do last batch of gemm
-            for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
-            {
-#if 0
-                for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
-                {
-                    for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
-                    {
-                        const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
-                        const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
-                        const index_t cindex = c_thread_mtx.Get1dIndex(i, j) +
-                                               (BatchPerThread - 1) * ThreadMatrixStrideC;
-
-                        f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
-                    }
-                }
-#elif 1
-                static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
-                              "asm is only for 16x4");
-
-                const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
-
-                for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
-                {
-                    const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
-                    const index_t cindex =
-                        c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC;
-
-                    asm volatile("\n \
-                    v_mac_f32 %0, %4, %5 \n \
-                    v_mac_f32 %1, %4, %6 \n \
-                    v_mac_f32 %2, %4, %7 \n \
-                    v_mac_f32 %3, %4, %8 \n \
-                    "
-                                 : "=v"(p_c_thread[cindex + 0]),
-                                   "=v"(p_c_thread[cindex + 1]),
-                                   "=v"(p_c_thread[cindex + 2]),
-                                   "=v"(p_c_thread[cindex + 3])
-                                 : "v"(p_a_thread[aindex]),
-                                   "v"(p_b_thread[bindex + 0]),
-                                   "v"(p_b_thread[bindex + 1]),
-                                   "v"(p_b_thread[bindex + 2]),
-                                   "v"(p_b_thread[bindex + 3]),
-                                   "0"(p_c_thread[cindex + 0]),
-                                   "1"(p_c_thread[cindex + 1]),
-                                   "2"(p_c_thread[cindex + 2]),
-                                   "3"(p_c_thread[cindex + 3]));
-                }
-#endif
-            }
-        }
-    }
-
     template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
     __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
                                                     FloatC* __restrict__ p_c_block) const
 ...
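Note on the inline asm deleted above: each iteration accumulates one element of A against four elements of B using four v_mac_f32 instructions, with the "0" through "3" matching constraints tying each output VGPR to its own previous value so the multiply-accumulate can update in place. A minimal standalone sketch of that 1x4 pattern (the helper name mac_1x4 is illustrative; the asm branch is left disabled because it only assembles for AMD GCN targets such as gfx906, and the plain C++ branch shows the equivalent arithmetic):

    // c[0..3] += a * b[0..3], either via GCN v_mac_f32 asm or plain C++.
    __device__ void mac_1x4(float a, const float* b, float* c)
    {
    #if 0 // GCN asm path, enable only when compiling for gfx9-class targets
        asm volatile("\n \
        v_mac_f32 %0, %4, %5 \n \
        v_mac_f32 %1, %4, %6 \n \
        v_mac_f32 %2, %4, %7 \n \
        v_mac_f32 %3, %4, %8 \n \
        "
                     : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])
                     : "v"(a), "v"(b[0]), "v"(b[1]), "v"(b[2]), "v"(b[3]),
                       "0"(c[0]), "1"(c[1]), "2"(c[2]), "3"(c[3]));
    #else
        for(int j = 0; j < 4; ++j)
            c[j] += a * b[j]; // same arithmetic, left to the compiler
    #endif
    }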
src/include/blockwise_gemm.hip.hpp  (+13 −367)

 #pragma once
 #include "threadwise_gemm.hip.hpp"

-template <index_t BlockSize,
-          class BlockMatrixA,
-          class BlockMatrixB,
-          class ThreadMatrixC,
-          bool TransA,
-          bool TransB,
-          bool TransC,
-          index_t KPerThreadLoop,
-          index_t MThreadPerCluster,
-          index_t NThreadPerCluster,
-          bool DistributeThreadAlongColumnFirst>
-struct BlockwiseGemmBlockABlockBThreadC
-{
-    index_t mMyThreadOffsetA = 0;
-    index_t mMyThreadOffsetB = 0;
-
-    struct MatrixIndex
-    {
-        index_t row;
-        index_t col;
-    };
-
-    __device__ BlockwiseGemmBlockABlockBThreadC()
-    {
-        constexpr auto a_block_mtx = BlockMatrixA{};
-        constexpr auto b_block_mtx = BlockMatrixB{};
-
-        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-
-        mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
-                                     : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
-
-        mMyThreadOffsetB = (!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
-                                     : b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0);
-
-#if 0
-        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-        {
-            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
-            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
-            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
-
-            printf("%u %u, %u %u %u, %u %u\n",
-                   get_block_1d_id(),
-                   get_thread_local_1d_id(),
-                   c_thread_mtx_index.batch,
-                   c_thread_mtx_index.row,
-                   c_thread_mtx_index.col,
-                   mMyThreadOffsetA,
-                   mMyThreadOffsetB);
-        }
-#endif
-    }
-
-    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
-    {
-        if(TransA && (!TransB) && (!TransC))
-        {
-            constexpr auto a_block_mtx = BlockMatrixA{};
-            constexpr auto b_block_mtx = BlockMatrixB{};
-
-            static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
-                          "wrong! k dimension not consistent!");
-
-            constexpr index_t MPerBlock = a_block_mtx.NCol();
-            constexpr index_t NPerBlock = b_block_mtx.NCol();
-
-            constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-            // divide thread work
-            constexpr index_t MPerThread = c_thread_mtx.NRow();
-            constexpr index_t NPerThread = c_thread_mtx.NCol();
-
-            static_assert(MPerBlock % (MPerThread * MThreadPerCluster) == 0,
-                          "MPerBlock % (MPerThread * MThreadPerCluster) != 0");
-            static_assert(NPerBlock % (NPerThread * NThreadPerCluster) == 0,
-                          "NPerBlock % (NPerThread * NThreadPerCluster) != 0");
-
-            constexpr index_t MClusterWork =
-                (MPerBlock + MPerThread * MThreadPerCluster - 1) / (MPerThread * MThreadPerCluster);
-            constexpr index_t NClusterWork =
-                (NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster);
-
-            static_assert(BlockSize ==
-                              (MClusterWork * MThreadPerCluster) * (NClusterWork * NThreadPerCluster),
-                          "wrong! wrong BlockSize");
-
-            if(DistributeThreadAlongColumnFirst)
-            {
-                const index_t cluster_work_block_id =
-                    thread_id / (MThreadPerCluster * NThreadPerCluster);
-
-                const index_t thread_work_cluster_id =
-                    thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster);
-
-                const index_t m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
-                const index_t n_cluster_work_block_id =
-                    cluster_work_block_id - m_cluster_work_block_id * NClusterWork;
-
-                const index_t m_thread_work_cluster_id = thread_work_cluster_id / NThreadPerCluster;
-                const index_t n_thread_work_cluster_id =
-                    thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster;
-
-#if 0
-                if(get_block_1d_id() == 0)
-                {
-                    printf("%u %u, \t"
-                           "MClusterWork %u MThreadPerCluster %u NClusterWork %u NThreadPerCluster %u \t"
-                           "m_cluster_work_block_id %u n_cluster_work_block_id %u \t"
-                           "m_thread_work_cluster_id %u n_thread_work_cluster_id %u \t"
-                           "\n",
-                           get_block_1d_id(), get_thread_local_1d_id(),
-                           MClusterWork, MThreadPerCluster, NClusterWork, NThreadPerCluster,
-                           m_cluster_work_block_id, n_cluster_work_block_id,
-                           m_thread_work_cluster_id, n_thread_work_cluster_id);
-                }
-#endif
-
-                return MatrixIndex{m_cluster_work_block_id * (MThreadPerCluster * MPerThread) +
-                                       m_thread_work_cluster_id * MPerThread,
-                                   n_cluster_work_block_id * (NThreadPerCluster * NPerThread) +
-                                       n_thread_work_cluster_id * NPerThread};
-            }
-            else
-            {
-                // not implemented
-                assert(false);
-            }
-        }
-        else
-        {
-            // not implemented
-            assert(false);
-        }
-    }
-
-    // this should be optimized away if input is known
-    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
-                                                                      index_t n_in_c)
-    {
-        return MatrixIndex{m_in_c, n_in_c};
-    }
-
-    template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void Run(const FloatA* __restrict__ p_a_block,
-                        const FloatB* __restrict__ p_b_block,
-                        FloatC* __restrict__ p_c_thread,
-                        Accumulator f_accum) const
-    {
-        if(TransA && (!TransB) && (!TransC))
-        {
-            constexpr auto True  = integral_constant<bool, true>{};
-            constexpr auto False = integral_constant<bool, false>{};
-
-            constexpr auto a_block_mtx  = BlockMatrixA{};
-            constexpr auto b_block_mtx  = BlockMatrixB{};
-            constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-            constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
-
-            constexpr index_t MPerThread = c_thread_mtx.NRow();
-            constexpr index_t NPerThread = c_thread_mtx.NCol();
-
-            // a is transposed, b is not
-            constexpr auto a_thread_mtx =
-                make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
-            constexpr auto b_thread_mtx =
-                make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
-
-            FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-            FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
-
-            // loop over k
-            for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
-            {
-                threadwise_matrix_copy(a_block_mtx,
-                                       p_a_block + mMyThreadOffsetA + k_begin * a_block_mtx.RowStride(),
-                                       a_thread_mtx,
-                                       p_a_thread,
-                                       a_thread_mtx.GetLengths());
-
-                threadwise_matrix_copy(b_block_mtx,
-                                       p_b_block + mMyThreadOffsetB + k_begin * b_block_mtx.RowStride(),
-                                       b_thread_mtx,
-                                       p_b_thread,
-                                       b_thread_mtx.GetLengths());
-
-                threadwise_gemm(a_thread_mtx,
-                                True,
-                                p_a_thread,
-                                b_thread_mtx,
-                                False,
-                                p_b_thread,
-                                c_thread_mtx,
-                                False,
-                                p_c_thread,
-                                f_accum);
-            }
-        }
-    }
-};
-
 // if following number are power of 2, index calculation shall be greatly reduced:
 // MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
 template <index_t BlockSize,
 ...
@@ -332,11 +123,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                            n_repeat * NPerLevel1Cluster + n_in_sub_c};
     }

-    template <class FloatA, class FloatB, class FloatC, class Accumulator>
+#if DEVICE_BACKEND_HIP
+    template <class FloatA, class FloatB, class FloatC>
     __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                             const FloatB* __restrict__ p_b_block,
-                            FloatC* __restrict__ p_c_thread,
-                            Accumulator f_accum) const
+                            FloatC* __restrict__ p_c_thread) const
     {
         constexpr auto True  = integral_constant<bool, true>{};
         constexpr auto False = integral_constant<bool, false>{};
 ...
@@ -414,12 +205,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
         outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
     }
+#endif

-    template <class FloatA, class FloatB, class FloatC, class Accumulator>
+    template <class FloatA, class FloatB, class FloatC>
     __device__ void Run(const FloatA* const __restrict__ p_a_block,
                         const FloatB* const __restrict__ p_b_block,
-                        FloatC* const __restrict__ p_c_thread,
-                        Accumulator f_accum) const
+                        FloatC* const __restrict__ p_c_thread) const
     {
         constexpr auto True  = integral_constant<bool, true>{};
         constexpr auto False = integral_constant<bool, false>{};
 ...
@@ -499,16 +290,14 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                             p_b_thread,
                             c_thread_mtx,
                             False,
-                            p_c_thread,
-                            f_accum);
+                            p_c_thread);
         }
     }

-    template <class FloatA, class FloatB, class FloatC, class Accumulator>
+    template <class FloatA, class FloatB, class FloatC>
     __device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block,
                                              FloatB* const p_b_block,
-                                             FloatC* p_c_thread,
-                                             Accumulator f_accum) const
+                                             FloatC* p_c_thread) const
     {
         constexpr auto True  = integral_constant<bool, true>{};
         constexpr auto False = integral_constant<bool, false>{};
 ...
@@ -618,8 +407,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                             p_b_thread_now,
                             c_thread_mtx,
                             False,
-                            p_c_thread,
-                            f_accum);
+                            p_c_thread);
         }

         // last loop
 ...
@@ -636,149 +424,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                             p_b_thread_now,
                             c_thread_mtx,
                             False,
-                            p_c_thread,
-                            f_accum);
+                            p_c_thread);
     }
-
-    template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void Run_v2(const FloatA* __restrict__ p_a_block,
-                           const FloatB* __restrict__ p_b_block,
-                           FloatC* __restrict__ p_c_thread,
-                           Accumulator f_accum) const
-    {
-        constexpr auto True  = integral_constant<bool, true>{};
-        constexpr auto False = integral_constant<bool, false>{};
-
-        constexpr auto a_block_mtx  = BlockMatrixA{};
-        constexpr auto b_block_mtx  = BlockMatrixB{};
-        constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-        constexpr index_t M = a_block_mtx.NCol();
-        constexpr index_t N = b_block_mtx.NCol();
-        constexpr index_t K = a_block_mtx.NRow();
-
-        constexpr index_t MPerThread = c_thread_mtx.NRow();
-        constexpr index_t NPerThread = c_thread_mtx.NCol();
-
-        // thread A-sub, B-sub, C-sub
-        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
-        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
-        constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
-            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
-
-        // thread A, B
-        constexpr auto a_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
-        constexpr auto b_thread_mtx =
-            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
-
-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
-
-        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
-
-        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
-        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
-
-#pragma unroll
-        // loop over k
-        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
-        {
-            // C-sub(s) in first row-wise subblock of C
-            {
-                // copy first A-sub
-                threadwise_matrix_copy(a_block_mtx,
-                                       p_a_block + a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA,
-                                       a_thread_mtx,
-                                       p_a_thread,
-                                       a_thread_sub_mtx.GetLengths());
-
-                // copy first B-sub
-                threadwise_matrix_copy(b_block_mtx,
-                                       p_b_block + b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB,
-                                       b_thread_mtx,
-                                       p_b_thread,
-                                       b_thread_sub_mtx.GetLengths());
-
-                // do first sub GEMM
-                threadwise_gemm(a_thread_sub_mtx,
-                                True,
-                                p_a_thread,
-                                b_thread_sub_mtx,
-                                False,
-                                p_b_thread,
-                                c_thread_sub_mtx,
-                                False,
-                                p_c_thread,
-                                f_accum);
-
-#pragma unroll
-                // copy next B-sub, and do GEMM
-                for(index_t n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
-                {
-                    threadwise_matrix_copy(
-                        b_block_mtx,
-                        p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
-                            mMyThreadOffsetB,
-                        b_thread_mtx,
-                        p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
-                        b_thread_sub_mtx.GetLengths());
-
-                    threadwise_gemm(a_thread_sub_mtx,
-                                    True,
-                                    p_a_thread,
-                                    b_thread_sub_mtx,
-                                    False,
-                                    p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
-                                    c_thread_sub_mtx,
-                                    False,
-                                    p_c_thread + c_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
-                                    f_accum);
-                }
-
-#pragma unroll
-                // loop over rest of row-wise subblock
-                // all B-sub(s) has been copied, so only A-sub(s) need to be copied
-                for(index_t m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
-                {
-                    // copy a A-sub
-                    threadwise_matrix_copy(
-                        a_block_mtx,
-                        p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
-                            mMyThreadOffsetA,
-                        a_thread_mtx,
-                        p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
-                        a_thread_sub_mtx.GetLengths());
-
-                    // do some GEMMs
-                    for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
-                    {
-                        threadwise_gemm(
-                            a_thread_sub_mtx,
-                            True,
-                            p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
-                            b_thread_sub_mtx,
-                            False,
-                            p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
-                            c_thread_sub_mtx,
-                            False,
-                            p_c_thread + c_thread_mtx.Get1dIndex(m_repeat * MPerThreadSubC,
-                                                                 n_repeat * NPerThreadSubC),
-                            f_accum);
-                    }
-                }
-            }
-        }
-    }
 };
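Note on the signature changes in this file (and the matching ones in threadwise_gemm.hip.hpp further down): the Accumulator template parameter and the f_accum argument are dropped from Run, Run_asm and Run_RegisterDoubleBuffer, so accumulation is no longer a caller-supplied callable but a plain += inside threadwise_gemm. A minimal sketch of the two shapes, with illustrative function names only:

    // Old style: the caller threads an accumulator functor through the call chain,
    // e.g. the lambda [](auto& acc, const auto&& v) { acc += v; } seen in the gridwise code.
    template <class Acc>
    __device__ void inner_update_with_functor(float& c, float a, float b, Acc f_accum)
    {
        f_accum(c, a * b);
    }

    // New style after this merge: the accumulation is hard-coded.
    __device__ void inner_update_plain(float& c, float a, float b)
    {
        c += a * b;
    }

Removing the extra template parameter shortens every call site, which is visible throughout the hunks above where "f_accum);" lines disappear.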
src/include/common.hip.hpp  (+7 −2)

 ...
@@ -26,7 +26,7 @@ __host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b)
     return (a + b - 1) / b;
 }
 namespace mod_conv {
 // namespace mod_conv
 template <class T>
 __host__ __device__ constexpr T max(T x, T y)
 {
 ...
@@ -62,4 +62,9 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
     return x < y ? x : y;
 }

 } // namespace mod_conv
+
+#if DEVICE_BACKEND_HIP
+// cast a pointer of LDS to its address
+extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]];
+#endif
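Note: the second hunk's header shows the variadic form of these helpers, __host__ __device__ constexpr T min(T x, Ts... xs), and the gridwise headers use the matching mod_conv::max to pick an LDS alignment from several copy widths. A self-contained sketch of such a variadic constexpr max; max_of and the index_t alias below are stand-ins, not the exact bodies in common.hip.hpp:

    #include <cstdint>
    using index_t = std::uint32_t; // stand-in for the repository's index_t

    // Two-argument base case, mirroring the max(T x, T y) in the first hunk.
    template <class T>
    __host__ __device__ constexpr T max_of(T x, T y)
    {
        return x > y ? x : y;
    }

    // Variadic case: peel one argument per recursion until the base case applies.
    template <class T, class... Ts>
    __host__ __device__ constexpr T max_of(T x, Ts... xs)
    {
        return max_of(x, max_of(xs...));
    }

    // Usage in the same spirit as the gridwise code:
    // constexpr index_t max_align = max_of(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
    static_assert(max_of(index_t(4), index_t(2), index_t(8)) == 8, "sanity check");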
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp → src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp  (+8 −81)

 ...
@@ -34,84 +34,13 @@ template <index_t GridSize,
           index_t WeiBlockCopyThreadPerDim1,
           index_t InBlockCopyDataPerRead,
           index_t WeiBlockCopyDataPerRead>
-struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
+struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
 {
-    __host__ __device__ constexpr index_t GetInputBlockElementSpace() const
-    {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
-
-        constexpr auto in_chwn_global_desc  = InGlobalDesc{};
-        constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
-
-        constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
-        constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
-
-        constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
-        constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
-
-        constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
-
-        // tensor view of blockwise input
-        // be careful of alignment
-        constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
-            Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
-
-        // LDS: be careful of alignment
-        constexpr index_t max_align =
-            mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
-
-        return in_cb_block_desc.GetElementSpace(Number<max_align>{});
-    }
-
-    __host__ __device__ constexpr index_t GetWeightBlockElementSpace() const
-    {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
-
-        constexpr auto in_chwn_global_desc  = InGlobalDesc{};
-        constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
-
-        constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
-        constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
-
-        constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
-        constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
-
-        constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
-
-        // tensor view of blockwise weight
-        // be careful of alignment
-        constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
-            Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
-
-        // LDS: be careful of alignment
-        constexpr index_t max_align = mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
-
-        return wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
-    }
-
-    __host__ __device__ constexpr index_t GetDynamicSharedMemoryUsage() const
-    {
-        return (GetInputBlockElementSpace() + GetWeightBlockElementSpace()) * sizeof(Float);
-    }
-
-    __device__ constexpr static Float* GetSharedMemoryBegin()
-    {
-        extern __shared__ Float s[];
-        return s;
-    }
-
-    __global__ static void Run(const Float* const __restrict__ p_in_global,
+    __host__ __device__ constexpr GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn() {}
+
+    __device__ void Run(const Float* const __restrict__ p_in_global,
                         const Float* const __restrict__ p_wei_global,
-                        Float* const __restrict__ p_out_global)
+                        Float* const __restrict__ p_out_global) const
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
 ...
@@ -279,8 +208,8 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
         constexpr index_t wei_block_element_space =
             wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});

-        Float* const p_in_block  = GetSharedMemoryBegin();
-        Float* const p_wei_block = GetSharedMemoryBegin() + in_block_element_space;
+        __shared__ Float p_in_block[in_block_element_space];
+        __shared__ Float p_wei_block[wei_block_element_space];

        const Float* p_in_global_block_offset =
            p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
 ...
@@ -340,7 +269,6 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
             {
                 for(index_t x = 0; x < X; ++x)
                 {
-                    auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
 #if 0
                     blockwise_gemm.Run
 #elif 0
 ...
@@ -350,8 +278,7 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
 #endif
                         (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                          p_in_block + y * Wi + x,
-                         p_out_thread,
-                         f_accum);
+                         p_out_thread);
                 }
             }
         }
 ...
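Note on the shared-memory change above: the old kernel carved both tiles out of one dynamically sized extern __shared__ region (GetSharedMemoryBegin() plus an offset, sized at launch time by GetDynamicSharedMemoryUsage()), while the new __device__ Run declares fixed-size __shared__ arrays, which is why the driver now passes 0 bytes of dynamic shared memory. A minimal sketch of the two layouts; the kernel names and the 1024/512 element counts are made up for illustration:

    // Old scheme: one dynamically sized LDS region, partitioned by pointer arithmetic.
    __global__ void kernel_dynamic_lds()
    {
        extern __shared__ float s[];          // size supplied as the 3rd launch parameter
        float* const p_in_block  = s;         // first in_block_element_space floats
        float* const p_wei_block = s + 1024;  // then the weight tile (offset is illustrative)
        // ... tile loads, blockwise GEMM, stores ...
    }

    // New scheme: compile-time-sized __shared__ arrays inside the kernel body;
    // the launch passes 0 for dynamic shared memory.
    __global__ void kernel_static_lds()
    {
        constexpr unsigned in_block_element_space  = 1024; // illustrative
        constexpr unsigned wei_block_element_space = 512;  // illustrative
        __shared__ float p_in_block[in_block_element_space];
        __shared__ float p_wei_block[wei_block_element_space];
        // ... tile loads, blockwise GEMM, stores ...
    }

With the static form the sizes must be compile-time constants, which they are here since they come from constexpr tensor descriptors.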
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp → src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp  (+11 −14)

 ...
@@ -34,13 +34,16 @@ template <index_t GridSize,
           index_t WeiBlockCopyThreadPerDim1,
           index_t InBlockCopyDataPerRead,
           index_t WeiBlockCopyDataPerRead>
-struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
+struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
 {
-    __host__ __device__ constexpr index_t GetDynamicSharedMemoryUsage() const { return 0; }
+    __host__ __device__ constexpr GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer()
+    {
+    }

-    __global__ static void Run(const Float* const __restrict__ p_in_global,
+    __device__ void Run(const Float* const __restrict__ p_in_global,
                         const Float* const __restrict__ p_wei_global,
-                        Float* const __restrict__ p_out_global)
+                        Float* const __restrict__ p_out_global) const
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
 ...
@@ -312,7 +315,6 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
             {
                 for(index_t x = 0; x < X; ++x)
                 {
-                    auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
 #if 0
                     blockwise_gemm.Run
 #elif 0
 ...
@@ -322,8 +324,7 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
 #endif
                         (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                          p_in_block_now + y * Wi + x,
-                         p_out_thread,
-                         f_accum);
+                         p_out_thread);
                 }
             }
 ...
@@ -366,7 +367,6 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
             {
                 for(index_t x = 0; x < X; ++x)
                 {
-                    auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
 #if 0
                     blockwise_gemm.Run
 #elif 1
 ...
@@ -376,8 +376,7 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
 #endif
                         (p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                          p_in_block_double + y * Wi + x,
-                         p_out_thread,
-                         f_accum);
+                         p_out_thread);
                 }
             }
 ...
@@ -397,7 +396,6 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
             {
                 for(index_t x = 0; x < X; ++x)
                 {
-                    auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
 #if 0
                     blockwise_gemm.Run
 #elif 1
 ...
@@ -408,8 +406,7 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
                         (p_wei_block_double + in_block_element_space +
                             wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                          p_in_block_double + wei_block_element_space + y * Wi + x,
-                         p_out_thread,
-                         f_accum);
+                         p_out_thread);
                 }
             }
         }
 ...
src/include/gridwise_convolution_wrapper.hip.hpp  (new file, mode 100644, +9 −0)

+#pragma once
+
+template <class GridwiseConvolution, class T>
+__global__ void run_gridwise_convolution(const T* const __restrict__ p_in_global,
+                                         const T* const __restrict__ p_wei_global,
+                                         T* const __restrict__ p_out_global)
+{
+    GridwiseConvolution{}.Run(p_in_global, p_wei_global, p_out_global);
+}
src/include/threadwise_gemm.hip.hpp  (+3 −29)

 ...
@@ -12,7 +12,6 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
     constexpr auto src_mtx = SrcMatrix{};
     constexpr auto dst_mtx = DstMatrix{};

-#if 1
     for(index_t i = 0; i < NRow; ++i)
     {
         for(index_t j = 0; j < NCol; ++j)
 ...
@@ -23,20 +22,6 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
             p_dst[dst_index] = p_src[src_index];
         }
     }
-#else
-    static_assert(NCol == 4, "only for NCol == 4");
-
-    for(index_t i = 0; i < NRow; ++i)
-    {
-        const index_t src_index = src_mtx.Get1dIndex(i, 0);
-        const index_t dst_index = dst_mtx.Get1dIndex(i, 0);
-
-        Float4* reg_p = (Float4*)&p_dst[dst_index];
-        Float4* loc_p = (Float4*)&p_src[src_index];
-
-        ds_read_b128(reg_p[0], (void*)&loc_p[0]);
-    }
-#endif
 }

 template <class MatrixA,
 ...
@@ -47,8 +32,7 @@ template <class MatrixA,
           bool TransC,
           class FloatA,
           class FloatB,
-          class FloatC,
-          class Accumulator>
+          class FloatC>
 __device__ void threadwise_gemm(MatrixA,
                                 integral_constant<bool, TransA>,
                                 const FloatA* __restrict__ p_a_thread,
 ...
@@ -57,8 +41,7 @@ __device__ void threadwise_gemm(MatrixA,
                                 const FloatB* __restrict__ p_b_thread,
                                 MatrixC,
                                 integral_constant<bool, TransC>,
-                                FloatC* __restrict__ p_c_thread,
-                                Accumulator f_accum)
+                                FloatC* __restrict__ p_c_thread)
 {
     if(TransA && (!TransB) && (!TransC))
     {
 ...
@@ -72,26 +55,17 @@ __device__ void threadwise_gemm(MatrixA,
         for(index_t k = 0; k < K; ++k)
         {
-#if 1
             for(index_t i = 0; i < M; i++)
             {
-                const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
                 for(index_t j = 0; j < N; j++)
                 {
+                    const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
                     const index_t bindex = b_mtx.Get1dIndex(k, j);
                     const index_t cindex = c_mtx.Get1dIndex(i, j);

                     p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex];
                 }
             }
-#else
-            const Float4* a_vec = (const Float4*)p_a_thread;
-            const Float4* b_vec = (const Float4*)p_b_thread;
-            Float4* c_vec       = (Float4*)p_c_thread;
-
-            outerProduct8x8(a_vec, b_vec, c_vec);
-#endif
         }
     }
     else
 ...
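Note on the copy path removed above: the deleted #else branch hand-vectorized threadwise_matrix_copy by reinterpreting each 4-element row as a Float4 and issuing a 128-bit LDS read through the repository-specific ds_read_b128; only the scalar element-wise path survives. A standalone sketch of the same idea with the standard HIP/CUDA float4 vector type (the helper names are illustrative, and both pointers are assumed 16-byte aligned, which the vector load requires):

    // Move a row of exactly 4 contiguous floats as one 128-bit transaction.
    __device__ void copy_row_of_4(const float* __restrict__ p_src, float* __restrict__ p_dst)
    {
        *reinterpret_cast<float4*>(p_dst) = *reinterpret_cast<const float4*>(p_src);
    }

    // Scalar equivalent, matching the element-wise loop that remains in the file.
    __device__ void copy_row_of_4_scalar(const float* __restrict__ p_src, float* __restrict__ p_dst)
    {
        for(int j = 0; j < 4; ++j)
            p_dst[j] = p_src[j];
    }

Whether the vector form actually wins depends on the backend and on alignment, which is presumably why the asm-backed variant was kept behind a disabled preprocessor branch before being dropped here.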