Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
e624df92
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f848febacdc54c351ed0ed23fcc4c9349828021e"
Commit
e624df92
authored
Apr 09, 2019
by
Chao Liu
Browse files
enabled ds_read_b128 and ds_write_b128 on hip c++
parent
471830a0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
49 additions
and
86 deletions
+49
-86
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
+2
-2
driver/driver.hip.cpp
driver/driver.hip.cpp
+4
-4
script/compile-hip.sh
script/compile-hip.sh
+2
-1
src/include/blockwise_gemm.hip.hpp
src/include/blockwise_gemm.hip.hpp
+25
-37
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
...implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
+16
-42
No files found.
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
View file @
e624df92
...
@@ -189,7 +189,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
...
@@ -189,7 +189,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
#elif
1
#elif
0
// 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer
// 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer
constexpr
index_t
BPerBlock
=
64
;
constexpr
index_t
BPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
...
@@ -219,7 +219,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
...
@@ -219,7 +219,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr
index_t
OutThreadCopyDataPerWrite
=
4
;
constexpr
index_t
OutThreadCopyDataPerWrite
=
4
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif
0
#elif
1
// 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer
// 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer
constexpr
index_t
BPerBlock
=
128
;
constexpr
index_t
BPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
...
...
driver/driver.hip.cpp
View file @
e624df92
...
@@ -409,7 +409,7 @@ int main(int argc, char* argv[])
...
@@ -409,7 +409,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
constexpr index_t WPad = 0;
#elif
1
#elif
0
// 3x3, 34x34
// 3x3, 34x34
constexpr
index_t
N
=
64
;
constexpr
index_t
N
=
64
;
constexpr
index_t
C
=
256
;
constexpr
index_t
C
=
256
;
...
@@ -583,7 +583,7 @@ int main(int argc, char* argv[])
...
@@ -583,7 +583,7 @@ int main(int argc, char* argv[])
constexpr
index_t
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif
0
#elif
1
// 1x1 filter, 14x14 image, C = 2048
// 1x1 filter, 14x14 image, C = 2048
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
2048
;
constexpr
index_t
C
=
2048
;
...
@@ -667,9 +667,9 @@ int main(int argc, char* argv[])
...
@@ -667,9 +667,9 @@ int main(int argc, char* argv[])
device_direct_convolution_2_nchw_kcyx_nkhw
device_direct_convolution_2_nchw_kcyx_nkhw
#elif 0
#elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 0
#elif 0
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 1
device_implicit_gemm_convolution_2_chwn_cyxk_khwn
device_implicit_gemm_convolution_2_chwn_cyxk_khwn
#endif
#endif
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
...
...
script/compile-hip.sh
View file @
e624df92
#!/bin/bash
#!/bin/bash
export
KMDUMPISA
=
1
export
KMDUMPISA
=
1
export
KMDUMPLLVM
=
1
export
KMDUMPLLVM
=
1
export
KMOPTLLC
=
-mattr
=
+enable-ds128
make
-j
driver
make
-j
driver
/opt/rocm/hcc/bin/llvm-objdump
-mcpu
=
gfx906
-source
-line-numbers
driver/dump-gfx906.isabin
>
driver/dump-gfx906.isabin.
isa
/opt/rocm/hcc/bin/llvm-objdump
-mcpu
=
gfx906
-source
-line-numbers
driver/dump-gfx906.isabin
>
driver/dump-gfx906.isabin.
asm
src/include/blockwise_gemm.hip.hpp
View file @
e624df92
...
@@ -132,10 +132,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
...
@@ -132,10 +132,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
const
FloatB
*
__restrict__
p_b_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
)
const
FloatC
*
__restrict__
p_c_thread
)
const
{
{
static_assert
(
is_same
<
FloatA
,
float
>::
value
&&
is_same
<
FloatB
,
float
>::
value
&&
is_same
<
FloatC
,
float
>::
value
,
"Run_asm only deal with float
\n
"
);
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
...
@@ -164,56 +160,48 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
...
@@ -164,56 +160,48 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr
auto
b_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
constexpr
auto
b_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpace
()];
FloatB
p_b_thread
[
b_thread_mtx
.
GetElementSpace
()];
constexpr
index_t
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
constexpr
index_t
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
;
// assertion for inline asm
static_assert
(
is_same
<
FloatA
,
float
>::
value
&&
is_same
<
FloatB
,
float
>::
value
&&
is_same
<
FloatC
,
float
>::
value
,
"Run_asm only deal with float
\n
"
);
static_assert
(
MPerThreadSubC
==
4
&&
NPerThreadSubC
==
4
&&
KPerThreadLoop
==
1
&&
static_assert
(
MPerThreadSubC
==
4
&&
NPerThreadSubC
==
4
&&
KPerThreadLoop
==
1
&&
MPerThread
==
8
&&
NPerThread
==
8
,
MPerThread
==
8
&&
NPerThread
==
8
,
"Run_asm cannot deal with this GEMM shape yet
\n
"
);
"Run_asm cannot deal with this GEMM shape yet
\n
"
);
using
Float4
=
vector_type
<
float
,
4
>::
MemoryType
;
using
Float4
=
vector_type
<
float
,
4
>::
MemoryType
;
float
p_thread
[
a_thread_mtx
.
GetElementSpace
()
+
b_thread_mtx
.
GetElementSpace
()];
FloatA
*
p_a_thread
=
p_thread
;
FloatB
*
p_b_thread
=
p_thread
+
a_thread_mtx
.
GetElementSpace
();
constexpr
index_t
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
constexpr
index_t
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
;
constexpr
index_t
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
index_t
NRepeat
=
NPerThread
/
NPerThreadSubC
;
Float4
*
reg_a
=
(
Float4
*
)(
p_a_thread
);
Float4
*
reg_a
=
(
Float4
*
)(
p_a_thread
);
Float4
*
reg_b
=
(
Float4
*
)(
p_b_thread
);
Float4
*
reg_b
=
(
Float4
*
)(
p_b_thread
);
Float4
*
reg_c
=
(
Float4
*
)(
p_c_thread
);
Float4
*
reg_c
=
(
Float4
*
)(
p_c_thread
);
void
*
a_loc
=
(
void
*
)(
p_a_block
+
mMyThreadOffsetA
);
void
*
b_loc
=
(
void
*
)(
p_b_block
+
mMyThreadOffsetB
);
reg_a
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
]);
reg_b
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
]);
int
lds_a_block_off
=
sizeof
(
Float
)
*
M
;
reg_b
[
1
]
=
int
lds_b_block_off
=
sizeof
(
Float
)
*
N
;
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
+
NPerLevel1Cluster
]);
int
lds_a_block_off_1
=
MPerLevel1Cluster
*
sizeof
(
Float
);
reg_a
[
1
]
=
int
lds_b_block_off_1
=
NPerLevel1Cluster
*
sizeof
(
Float
);
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
+
MPerLevel1Cluster
]);
ds_read_b128
(
reg_a
[
0
],
a_loc
,
0
);
ds_read_b128
(
reg_b
[
0
],
b_loc
,
0
);
ds_read_b128
(
reg_b
[
1
],
b_loc
,
lds_b_block_off_1
);
ds_read_b128
(
reg_a
[
1
],
a_loc
,
lds_a_block_off_1
);
lgkmcnt
(
2
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
lgkmcnt
(
1
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
lgkmcnt
(
0
);
#pragma unroll
#pragma unroll
for
(
int
k
_i
=
1
;
k
_i
<
K
;
k_i
++
)
for
(
in
dex_
t
k
=
1
;
k
<
K
;
++
k
)
{
{
ds_read_b128
(
reg_a
[
0
],
a_loc
,
k_i
*
lds_a_block_off
);
reg_a
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
+
k
*
M
]
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
ds_read_b128
(
reg_b
[
0
],
b_loc
,
k_i
*
lds_b_block_off
);
reg_b
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
+
k
*
N
]
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
ds_read_b128
(
reg_b
[
1
],
b_loc
,
lds_b_block_off_1
+
k_i
*
lds_b_block_off
);
reg_b
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
ds_read_b128
(
reg_a
[
1
],
a_loc
,
lds_a_block_off_1
+
k_i
*
lds_a_block_off
);
&
p_b_block
[
mMyThreadOffsetB
+
k
*
N
+
NPerLevel1Cluster
]);
lgkmcnt
(
2
);
reg_a
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
+
k
*
M
+
MPerLevel1Cluster
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
lgkmcnt
(
1
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
lgkmcnt
(
0
);
}
}
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
...
...
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
View file @
e624df92
...
@@ -213,17 +213,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
...
@@ -213,17 +213,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_clipboard
);
p_wei_register_clipboard
);
#if 1
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
p_in_block_double
);
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
p_in_block_double
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block_double
);
p_wei_block_double
);
#else
vmcnt
(
0
);
blockwise_in_copy
.
RunStoreRegisterClipboard_asm
(
p_in_register_clipboard
,
p_in_block_double
);
blockwise_wei_copy
.
RunStoreRegisterClipboard_asm
(
p_wei_register_clipboard
,
p_wei_block_double
);
#endif
}
}
// register
// register
...
@@ -261,7 +253,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
...
@@ -261,7 +253,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global_block_offset
,
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global_block_offset
,
p_in_register_clipboard
);
p_in_register_clipboard
);
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_clipboard
);
p_wei_register_clipboard
);
...
@@ -271,31 +262,23 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
...
@@ -271,31 +262,23 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
{
{
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
{
{
#if
1
#if
0
blockwise_gemm.Run
blockwise_gemm.Run
#elif
0
#elif
0
blockwise_gemm
.
Run_RegisterDoubleBuffer
blockwise_gemm
.
Run_RegisterDoubleBuffer
#elif
0
#elif
1
blockwise_gemm
.
Run_asm
blockwise_gemm
.
Run_asm
#endif
#endif
(
p_wei_block_now
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
(
p_wei_block_now
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block_now
+
y
*
Wi
+
x
,
p_in_block_now
+
y
*
Wi
+
x
,
p_out_thread
);
p_out_thread
);
}
}
}
}
#if 1
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
p_in_block_next
);
p_in_block_next
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block_next
);
p_wei_block_next
);
#else
vmcnt
(
0
);
blockwise_in_copy
.
RunStoreRegisterClipboard_asm
(
p_in_register_clipboard
,
p_in_block_next
);
blockwise_wei_copy
.
RunStoreRegisterClipboard_asm
(
p_wei_register_clipboard
,
p_wei_block_next
);
#endif
}
}
}
}
...
@@ -320,32 +303,23 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
...
@@ -320,32 +303,23 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
{
{
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
{
{
#if
1
#if
0
blockwise_gemm.Run
blockwise_gemm.Run
#elif
0
#elif
0
blockwise_gemm
.
Run_RegisterDoubleBuffer
blockwise_gemm
.
Run_RegisterDoubleBuffer
#elif
0
#elif
1
blockwise_gemm
.
Run_asm
blockwise_gemm
.
Run_asm
#endif
#endif
(
p_wei_block_double
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
(
p_wei_block_double
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block_double
+
y
*
Wi
+
x
,
p_in_block_double
+
y
*
Wi
+
x
,
p_out_thread
);
p_out_thread
);
}
}
}
}
#if 1
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
p_in_block_double
+
in_block_space
);
p_in_block_double
+
in_block_space
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block_double
+
wei_block_space
);
p_wei_block_double
+
wei_block_space
);
#else
vmcnt
(
0
);
blockwise_in_copy
.
RunStoreRegisterClipboard_asm
(
p_in_register_clipboard
,
p_in_block_double
+
in_block_space
);
blockwise_wei_copy
.
RunStoreRegisterClipboard_asm
(
p_wei_register_clipboard
,
p_wei_block_double
+
wei_block_space
);
#endif
// odd
// odd
__syncthreads
();
__syncthreads
();
...
@@ -354,17 +328,17 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
...
@@ -354,17 +328,17 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
{
{
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
{
{
#if
1
#if
0
blockwise_gemm.Run
blockwise_gemm.Run
#elif
0
#elif
0
blockwise_gemm
.
Run_RegisterDoubleBuffer
blockwise_gemm
.
Run_RegisterDoubleBuffer
#elif
0
#elif
1
blockwise_gemm
.
Run_asm
blockwise_gemm
.
Run_asm
#endif
#endif
(
p_wei_block_double
+
wei_block_space
+
(
p_wei_block_double
+
wei_block_space
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block_double
+
in_block_space
+
y
*
Wi
+
x
,
p_in_block_double
+
in_block_space
+
y
*
Wi
+
x
,
p_out_thread
);
p_out_thread
);
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment