Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
114fdb58
Commit
114fdb58
authored
Apr 01, 2019
by
Jing Zhang
Browse files
4x4
parent
85c1ff1c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
248 additions
and
275 deletions
+248
-275
driver/driver.hip.cpp
driver/driver.hip.cpp
+1
-1
src/include/blockwise_gemm.hip.hpp
src/include/blockwise_gemm.hip.hpp
+246
-273
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
...idwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
+1
-1
No files found.
driver/driver.hip.cpp
View file @
114fdb58
...
...
@@ -580,7 +580,7 @@ int main(int argc, char* argv[])
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif
1
#elif
0
// 1x1 filter, 14x14 image, C = 2048
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
2048
;
...
...
src/include/blockwise_gemm.hip.hpp
View file @
114fdb58
...
...
@@ -3,6 +3,179 @@
// Declaration only: `__to_local` is an externally defined, C-linkage function
// that takes a `void*` and returns a `void*` qualified with address_space(3).
// It carries the `[[hc]]` attribute. In the visible part of this file it is
// referenced only from commented-out inline-asm experiments.
// NOTE(review): address space 3 is presumably the AMDGPU local (LDS) address
// space - confirm against the toolchain documentation.
extern
"C"
__attribute__
((
address_space
(
3
)))
void
*
__to_local
(
void
*
p
)[[
hc
]];
// Accumulates the 4x4 outer product of `a` and `b` into four accumulator rows
// using v_mac_f32 (dst += src0 * src1):
//
//   c0 += a.x * b      c1 += a.y * b
//   c2 += a.z * b      c3 += a.w * b
//
// a, b       : operands; only read.
// c0 .. c3   : accumulator rows; read and written in place.
//
// The accumulator components are bound as read-write ("+v") operands. They
// were previously passed as plain inputs, which hid the fact that the asm
// modifies them: the compiler was then free to keep a stale copy of an
// accumulator or to share its register with another value.
inline __device__ void outerProduct4x4(
    float4& a, float4& b, float4& c0, float4& c1, float4& c2, float4& c3)
{
    // Operand map for every statement below:
    //   %0..%3 = one accumulator row (x, y, z, w)
    //   %4     = the scalar taken from `a` for that row
    //   %5..%8 = b.x, b.y, b.z, b.w
    asm volatile("\n \
            v_mac_f32 %0, %4, %5 \n \
            v_mac_f32 %1, %4, %6 \n \
            v_mac_f32 %2, %4, %7 \n \
            v_mac_f32 %3, %4, %8 \n \
            "
                 : "+v"(c0.x), "+v"(c0.y), "+v"(c0.z), "+v"(c0.w)
                 : "v"(a.x), "v"(b.x), "v"(b.y), "v"(b.z), "v"(b.w));

    asm volatile("\n \
            v_mac_f32 %0, %4, %5 \n \
            v_mac_f32 %1, %4, %6 \n \
            v_mac_f32 %2, %4, %7 \n \
            v_mac_f32 %3, %4, %8 \n \
            "
                 : "+v"(c1.x), "+v"(c1.y), "+v"(c1.z), "+v"(c1.w)
                 : "v"(a.y), "v"(b.x), "v"(b.y), "v"(b.z), "v"(b.w));

    asm volatile("\n \
            v_mac_f32 %0, %4, %5 \n \
            v_mac_f32 %1, %4, %6 \n \
            v_mac_f32 %2, %4, %7 \n \
            v_mac_f32 %3, %4, %8 \n \
            "
                 : "+v"(c2.x), "+v"(c2.y), "+v"(c2.z), "+v"(c2.w)
                 : "v"(a.z), "v"(b.x), "v"(b.y), "v"(b.z), "v"(b.w));

    asm volatile("\n \
            v_mac_f32 %0, %4, %5 \n \
            v_mac_f32 %1, %4, %6 \n \
            v_mac_f32 %2, %4, %7 \n \
            v_mac_f32 %3, %4, %8 \n \
            "
                 : "+v"(c3.x), "+v"(c3.y), "+v"(c3.z), "+v"(c3.w)
                 : "v"(a.w), "v"(b.x), "v"(b.y), "v"(b.z), "v"(b.w));
}
// Emits a single `s_waitcnt lgkmcnt(cnt)` instruction, with `cnt` chosen at
// compile time.
//
// cnt : the literal placed in `lgkmcnt(...)`; only 0..6 are implemented.
//
// The count has to appear as a literal in the instruction text, hence one
// case per supported value. Any other value is now rejected at compile time;
// previously it fell through every `if` and silently emitted no wait at all.
template <uint32_t cnt>
inline __device__ void lgkmcnt()
{
    static_assert(cnt <= 6, "lgkmcnt<cnt>: only cnt in [0, 6] is implemented");

    switch(cnt)
    {
    case 0: asm volatile("\n s_waitcnt lgkmcnt(0) \n" ::); break;
    case 1: asm volatile("\n s_waitcnt lgkmcnt(1) \n" ::); break;
    case 2: asm volatile("\n s_waitcnt lgkmcnt(2) \n" ::); break;
    case 3: asm volatile("\n s_waitcnt lgkmcnt(3) \n" ::); break;
    case 4: asm volatile("\n s_waitcnt lgkmcnt(4) \n" ::); break;
    case 5: asm volatile("\n s_waitcnt lgkmcnt(5) \n" ::); break;
    case 6: asm volatile("\n s_waitcnt lgkmcnt(6) \n" ::); break;
    default: break; // unreachable: guarded by the static_assert above
    }
}
// Helper for shared_read_b128: when `off == OFF_LO`, issues the four
// ds_read_b128 instructions with literal offsets OFF_LO and OFF_HI.
// Outputs are early-clobber ("=&v") because the first read writes its
// destination while the address operands %4/%5 are still needed by the
// following reads; without '&' the register allocator may place an output on
// top of an address register.
#define CK_SHARED_READ_B128_CASE(OFF_LO, OFF_HI)                        \
    if(off == OFF_LO)                                                   \
    {                                                                   \
        asm volatile("\n"                                               \
                     "ds_read_b128 %0, %4 offset:" #OFF_LO "\n"         \
                     "ds_read_b128 %1, %4 offset:" #OFF_HI "\n"         \
                     "ds_read_b128 %2, %5 offset:" #OFF_LO "\n"         \
                     "ds_read_b128 %3, %5 offset:" #OFF_HI "\n"         \
                     : "=&v"(a0), "=&v"(a1), "=&v"(b0), "=&v"(b1)       \
                     : "v"(ldsA), "v"(ldsB));                           \
    }

// Issues four ds_read_b128 loads with a compile-time offset:
//
//   a0 <- [ldsA + off]         a1 <- [ldsA + off + 256]
//   b0 <- [ldsB + off]         b1 <- [ldsB + off + 256]
//
// off        : instruction offset; must be k * 512 with k in 0..7.
// a0,a1,b0,b1: destinations of the four 128-bit reads.
// ldsA, ldsB : address operands for the A and B reads; only read here.
//
// No s_waitcnt is emitted by this function, so the destinations must not be
// consumed until the caller has waited (e.g. via lgkmcnt<0>()).
//
// Unsupported offsets are rejected at compile time; previously they matched
// none of the branches and left all four outputs unwritten.
template <uint32_t off>
inline __device__ void shared_read_b128(
    float4& a0, float4& a1, float4& b0, float4& b1, uint32_t& ldsA, uint32_t& ldsB)
{
    static_assert(off % 512 == 0 && off / 512 <= 7,
                  "shared_read_b128<off>: off must be k * 512 with k in [0, 7]");

    CK_SHARED_READ_B128_CASE(0, 256)
    CK_SHARED_READ_B128_CASE(512, 768)
    CK_SHARED_READ_B128_CASE(1024, 1280)
    CK_SHARED_READ_B128_CASE(1536, 1792)
    CK_SHARED_READ_B128_CASE(2048, 2304)
    CK_SHARED_READ_B128_CASE(2560, 2816)
    CK_SHARED_READ_B128_CASE(3072, 3328)
    CK_SHARED_READ_B128_CASE(3584, 3840)
}

#undef CK_SHARED_READ_B128_CASE
template
<
index_t
BlockSize
,
class
BlockMatrixA
,
class
BlockMatrixB
,
...
...
@@ -334,11 +507,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
n_repeat
*
NPerLevel1Cluster
+
n_in_sub_c
};
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
>
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
,
index_t
block_off
>
__device__
void
Run_asm
(
const
FloatA
*
__restrict__
p_a_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
,
Accumulator
f_accum
)
const
Accumulator
f_accum
,
Number
<
block_off
>
)
const
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
...
...
@@ -387,52 +561,62 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
auto
a_src_index
=
a_block_mtx
.
Get1dIndex
(
k_begin
,
0
)
+
mMyThreadOffsetA
;
auto
b_src_index
=
b_block_mtx
.
Get1dIndex
(
k_begin
,
0
)
+
mMyThreadOffsetB
;
const
float4
*
a_loc
=
(
const
float4
*
)(
p_a_block
+
a_src_index
);
const
float4
*
b_loc
=
(
const
float4
*
)(
p_b_block
+
b_src_index
);
uint32_t
a_loc
=
block_off
+
a_src_index
;
uint32_t
b_loc
=
b_src_index
;
//const float4* a_loc = (const float4*)(p_b_block + block_off + a_src_index);
//const float4* b_loc = (const float4*)(p_b_block + b_src_index);
float4
*
reg
=
(
float4
*
)(
p_thread
);
float4
*
c_v
=
(
float4
*
)(
p_c_thread
);
reg
[
0
]
=
a_loc
[
0
];
reg
[
1
]
=
a_loc
[
16
];
reg
[
2
]
=
b_loc
[
0
];
reg
[
3
]
=
b_loc
[
8
];
//shared_read_b128<0>(reg[0], reg[1], reg[2], reg[3], a_loc, b_loc);
//reg[0] = a_loc[0];
//reg[1] = a_loc[16];
//reg[2] = b_loc[0];
//reg[3] = b_loc[8];
//asm volatile("\n \
//ds_read
2
_b
64
%0, %1
offset1:1
\n \
//
s_waitcnt lgkmcnt(0)
"
//: "=v"(reg[0])
//: "v"(__to_local((void *)
(a_loc)
))
//);
//ds_read_b
128
%0, %1 \n \
//"
//: "=v"(reg[0])
//: "v"
(a_loc)
//);
//asm volatile("\n \
//ds_read
2
_b
64
%0, %1
offset1:1
\n \
//
s_waitcnt lgkmcnt(0)
"
//: "=v"(reg[1])
//: "v"(__to_local((void *)
(a_loc +
16))
)
//);
//ds_read_b
128
%0, %1 \n \
//"
//: "=v"(reg[1])
//: "v"
(a_loc +
256
)
//);
//asm volatile("\n \
//ds_read
2
_b
64
%0, %1
offset1:1
\n \
//
s_waitcnt lgkmcnt(0)
"
//: "=v"(reg[2])
//: "v"(__to_local((void *)
(b_loc)
))
//);
//ds_read_b
128
%0, %1 \n \
//"
//: "=v"(reg[2])
//: "v"
(b_loc)
//);
//asm volatile("\n \
//ds_read
2
_b
64
%0, %1
offset1:1
\n \
//
s_waitcnt lgkmcnt(0)
"
//: "=v"(reg[3])
//: "v"(__to_local((void *)
(b_loc +
8))
)
//);
//ds_read_b
128
%0, %1\n \
//"
//: "=v"(reg[3])
//: "v"
(b_loc +
128
)
//);
//asm volatile("\n \
//ds_read2_b64 %0, %4 offset1:1 \n \
//ds_read2_b64 %1, %4 offset0:32 offset1:33 \n \
//ds_read2_b64 %2, %5 offset1:1 \n \
//ds_read2_b64 %3, %5 offset0:16 offset1:17 \n \
//s_waitcnt lgkmcnt(0)"
//: "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3])
//: "v"(__to_local((void *)(a_loc))), "v"(__to_local((void *)(b_loc)))
//);
asm
volatile
(
"
\n
\
ds_read_b128 %0, %4
\n
\
ds_read_b128 %1, %4 offset:16
\n
\
ds_read_b128 %2, %5
\n
\
ds_read_b128 %3, %5 offset:8
\n
\
"
:
"=v"
(
reg
[
0
]),
"=v"
(
reg
[
1
]),
"=v"
(
reg
[
2
]),
"=v"
(
reg
[
3
])
:
"v"
(
a_loc
),
"v"
(
b_loc
)
);
lgkmcnt
<
0
>
();
outerProduct4x4
(
reg
[
0
],
reg
[
2
],
c_v
[
0
],
c_v
[
1
],
c_v
[
2
],
c_v
[
3
]);
outerProduct4x4
(
reg
[
0
],
reg
[
3
],
c_v
[
4
],
c_v
[
5
],
c_v
[
6
],
c_v
[
7
]);
outerProduct4x4
(
reg
[
1
],
reg
[
2
],
c_v
[
8
],
c_v
[
9
],
c_v
[
10
],
c_v
[
11
]);
outerProduct4x4
(
reg
[
1
],
reg
[
3
],
c_v
[
12
],
c_v
[
13
],
c_v
[
14
],
c_v
[
15
]);
//asm volatile("\n \
//ds_read_b32 %0, %16 \n \
...
...
@@ -452,242 +636,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
//ds_read_b32 %14, %19 offset:2\n \
//ds_read_b32 %15, %19 offset:3\n \
//s_waitcnt lgkmcnt(0)"
//:
//"=v"(p_a_thread[0]),
//"=v"(p_a_thread[1]),
//"=v"(p_a_thread[2]),
//"=v"(p_a_thread[3]),
//"=v"(p_a_thread[4]),
//"=v"(p_a_thread[5]),
//"=v"(p_a_thread[6]),
//"=v"(p_a_thread[7]),
//"=v"(p_b_thread[0]),
//"=v"(p_b_thread[1]),
//"=v"(p_b_thread[2]),
//"=v"(p_b_thread[3]),
//"=v"(p_b_thread[4]),
//"=v"(p_b_thread[5]),
//"=v"(p_b_thread[6]),
//"=v"(p_b_thread[7])
//:
//"v"(__to_local((void *)(&p_a_block[0]))),
//"v"(__to_local((void *)(&p_a_block[64]))),
//"v"(__to_local((void *)(&p_b_block[0]))),
//"v"(__to_local((void *)(&p_b_block[32])))
//);
// C = A * B
asm
volatile
(
"
\n
\
v_mac_f32 %0, %64, %72
\n
\
v_mac_f32 %1, %64, %73
\n
\
v_mac_f32 %2, %64, %74
\n
\
v_mac_f32 %3, %64, %75
\n
\
v_mac_f32 %4, %64, %76
\n
\
v_mac_f32 %5, %64, %77
\n
\
v_mac_f32 %6, %64, %78
\n
\
v_mac_f32 %7, %64, %79
\n
\
v_mac_f32 %8, %65, %72
\n
\
v_mac_f32 %9, %65, %73
\n
\
v_mac_f32 %10, %65, %74
\n
\
v_mac_f32 %11, %65, %75
\n
\
v_mac_f32 %12, %65, %76
\n
\
v_mac_f32 %13, %65, %77
\n
\
v_mac_f32 %14, %65, %78
\n
\
v_mac_f32 %15, %65, %79
\n
\
v_mac_f32 %16, %66, %72
\n
\
v_mac_f32 %17, %66, %73
\n
\
v_mac_f32 %18, %66, %74
\n
\
v_mac_f32 %19, %66, %75
\n
\
v_mac_f32 %20, %66, %76
\n
\
v_mac_f32 %21, %66, %77
\n
\
v_mac_f32 %22, %66, %78
\n
\
v_mac_f32 %23, %66, %79
\n
\
v_mac_f32 %24, %67, %72
\n
\
v_mac_f32 %25, %67, %73
\n
\
v_mac_f32 %26, %67, %74
\n
\
v_mac_f32 %27, %67, %75
\n
\
v_mac_f32 %28, %67, %76
\n
\
v_mac_f32 %29, %67, %77
\n
\
v_mac_f32 %30, %67, %78
\n
\
v_mac_f32 %31, %67, %79
\n
\
v_mac_f32 %32, %68, %72
\n
\
v_mac_f32 %33, %68, %73
\n
\
v_mac_f32 %34, %68, %74
\n
\
v_mac_f32 %35, %68, %75
\n
\
v_mac_f32 %36, %68, %76
\n
\
v_mac_f32 %37, %68, %77
\n
\
v_mac_f32 %38, %68, %78
\n
\
v_mac_f32 %39, %68, %79
\n
\
v_mac_f32 %40, %69, %72
\n
\
v_mac_f32 %41, %69, %73
\n
\
v_mac_f32 %42, %69, %74
\n
\
v_mac_f32 %43, %69, %75
\n
\
v_mac_f32 %44, %69, %76
\n
\
v_mac_f32 %45, %69, %77
\n
\
v_mac_f32 %46, %69, %78
\n
\
v_mac_f32 %47, %69, %79
\n
\
v_mac_f32 %48, %70, %72
\n
\
v_mac_f32 %49, %70, %73
\n
\
v_mac_f32 %50, %70, %74
\n
\
v_mac_f32 %51, %70, %75
\n
\
v_mac_f32 %52, %70, %76
\n
\
v_mac_f32 %53, %70, %77
\n
\
v_mac_f32 %54, %70, %78
\n
\
v_mac_f32 %55, %70, %79
\n
\
v_mac_f32 %56, %71, %72
\n
\
v_mac_f32 %57, %71, %73
\n
\
v_mac_f32 %58, %71, %74
\n
\
v_mac_f32 %59, %71, %75
\n
\
v_mac_f32 %60, %71, %76
\n
\
v_mac_f32 %61, %71, %77
\n
\
v_mac_f32 %62, %71, %78
\n
\
v_mac_f32 %63, %71, %79
\n
\
"
:
"=v"
(
p_c_thread
[
0
]),
"=v"
(
p_c_thread
[
1
]),
"=v"
(
p_c_thread
[
2
]),
"=v"
(
p_c_thread
[
3
]),
"=v"
(
p_c_thread
[
4
]),
"=v"
(
p_c_thread
[
5
]),
"=v"
(
p_c_thread
[
6
]),
"=v"
(
p_c_thread
[
7
]),
"=v"
(
p_c_thread
[
8
]),
"=v"
(
p_c_thread
[
9
]),
"=v"
(
p_c_thread
[
10
]),
"=v"
(
p_c_thread
[
11
]),
"=v"
(
p_c_thread
[
12
]),
"=v"
(
p_c_thread
[
13
]),
"=v"
(
p_c_thread
[
14
]),
"=v"
(
p_c_thread
[
15
]),
"=v"
(
p_c_thread
[
16
]),
"=v"
(
p_c_thread
[
17
]),
"=v"
(
p_c_thread
[
18
]),
"=v"
(
p_c_thread
[
19
]),
"=v"
(
p_c_thread
[
20
]),
"=v"
(
p_c_thread
[
21
]),
"=v"
(
p_c_thread
[
22
]),
"=v"
(
p_c_thread
[
23
]),
"=v"
(
p_c_thread
[
24
]),
"=v"
(
p_c_thread
[
25
]),
"=v"
(
p_c_thread
[
26
]),
"=v"
(
p_c_thread
[
27
]),
"=v"
(
p_c_thread
[
28
]),
"=v"
(
p_c_thread
[
29
]),
"=v"
(
p_c_thread
[
30
]),
"=v"
(
p_c_thread
[
31
]),
"=v"
(
p_c_thread
[
32
]),
"=v"
(
p_c_thread
[
33
]),
"=v"
(
p_c_thread
[
34
]),
"=v"
(
p_c_thread
[
35
]),
"=v"
(
p_c_thread
[
36
]),
"=v"
(
p_c_thread
[
37
]),
"=v"
(
p_c_thread
[
38
]),
"=v"
(
p_c_thread
[
39
]),
"=v"
(
p_c_thread
[
40
]),
"=v"
(
p_c_thread
[
41
]),
"=v"
(
p_c_thread
[
42
]),
"=v"
(
p_c_thread
[
43
]),
"=v"
(
p_c_thread
[
44
]),
"=v"
(
p_c_thread
[
45
]),
"=v"
(
p_c_thread
[
46
]),
"=v"
(
p_c_thread
[
47
]),
"=v"
(
p_c_thread
[
48
]),
"=v"
(
p_c_thread
[
49
]),
"=v"
(
p_c_thread
[
50
]),
"=v"
(
p_c_thread
[
51
]),
"=v"
(
p_c_thread
[
52
]),
"=v"
(
p_c_thread
[
53
]),
"=v"
(
p_c_thread
[
54
]),
"=v"
(
p_c_thread
[
55
]),
"=v"
(
p_c_thread
[
56
]),
"=v"
(
p_c_thread
[
57
]),
"=v"
(
p_c_thread
[
58
]),
"=v"
(
p_c_thread
[
59
]),
"=v"
(
p_c_thread
[
60
]),
"=v"
(
p_c_thread
[
61
]),
"=v"
(
p_c_thread
[
62
]),
"=v"
(
p_c_thread
[
63
])
:
"v"
(
p_a_thread
[
0
]),
"v"
(
p_a_thread
[
1
]),
"v"
(
p_a_thread
[
2
]),
"v"
(
p_a_thread
[
3
]),
"v"
(
p_a_thread
[
4
]),
"v"
(
p_a_thread
[
5
]),
"v"
(
p_a_thread
[
6
]),
"v"
(
p_a_thread
[
7
]),
"v"
(
p_b_thread
[
0
]),
"v"
(
p_b_thread
[
1
]),
"v"
(
p_b_thread
[
2
]),
"v"
(
p_b_thread
[
3
]),
"v"
(
p_b_thread
[
4
]),
"v"
(
p_b_thread
[
5
]),
"v"
(
p_b_thread
[
6
]),
"v"
(
p_b_thread
[
7
]),
"0"
(
p_c_thread
[
0
]),
"1"
(
p_c_thread
[
1
]),
"2"
(
p_c_thread
[
2
]),
"3"
(
p_c_thread
[
3
]),
"4"
(
p_c_thread
[
4
]),
"5"
(
p_c_thread
[
5
]),
"6"
(
p_c_thread
[
6
]),
"7"
(
p_c_thread
[
7
]),
"8"
(
p_c_thread
[
8
]),
"9"
(
p_c_thread
[
9
]),
"10"
(
p_c_thread
[
10
]),
"11"
(
p_c_thread
[
11
]),
"12"
(
p_c_thread
[
12
]),
"13"
(
p_c_thread
[
13
]),
"14"
(
p_c_thread
[
14
]),
"15"
(
p_c_thread
[
15
]),
"16"
(
p_c_thread
[
16
]),
"17"
(
p_c_thread
[
17
]),
"18"
(
p_c_thread
[
18
]),
"19"
(
p_c_thread
[
19
]),
"20"
(
p_c_thread
[
20
]),
"21"
(
p_c_thread
[
21
]),
"22"
(
p_c_thread
[
22
]),
"23"
(
p_c_thread
[
23
]),
"24"
(
p_c_thread
[
24
]),
"25"
(
p_c_thread
[
25
]),
"26"
(
p_c_thread
[
26
]),
"27"
(
p_c_thread
[
27
]),
"28"
(
p_c_thread
[
28
]),
"29"
(
p_c_thread
[
29
]),
"30"
(
p_c_thread
[
30
]),
"31"
(
p_c_thread
[
31
]),
"32"
(
p_c_thread
[
32
]),
"33"
(
p_c_thread
[
33
]),
"34"
(
p_c_thread
[
34
]),
"35"
(
p_c_thread
[
35
]),
"36"
(
p_c_thread
[
36
]),
"37"
(
p_c_thread
[
37
]),
"38"
(
p_c_thread
[
38
]),
"39"
(
p_c_thread
[
39
]),
"40"
(
p_c_thread
[
40
]),
"41"
(
p_c_thread
[
41
]),
"42"
(
p_c_thread
[
42
]),
"43"
(
p_c_thread
[
43
]),
"44"
(
p_c_thread
[
44
]),
"45"
(
p_c_thread
[
45
]),
"46"
(
p_c_thread
[
46
]),
"47"
(
p_c_thread
[
47
]),
"48"
(
p_c_thread
[
48
]),
"49"
(
p_c_thread
[
49
]),
"50"
(
p_c_thread
[
50
]),
"51"
(
p_c_thread
[
51
]),
"52"
(
p_c_thread
[
52
]),
"53"
(
p_c_thread
[
53
]),
"54"
(
p_c_thread
[
54
]),
"55"
(
p_c_thread
[
55
]),
"56"
(
p_c_thread
[
56
]),
"57"
(
p_c_thread
[
57
]),
"58"
(
p_c_thread
[
58
]),
"59"
(
p_c_thread
[
59
]),
"60"
(
p_c_thread
[
60
]),
"61"
(
p_c_thread
[
61
]),
"62"
(
p_c_thread
[
62
]),
"63"
(
p_c_thread
[
63
]));
//:
//"=v"(p_a_thread[0]),
//"=v"(p_a_thread[1]),
//"=v"(p_a_thread[2]),
//"=v"(p_a_thread[3]),
//"=v"(p_a_thread[4]),
//"=v"(p_a_thread[5]),
//"=v"(p_a_thread[6]),
//"=v"(p_a_thread[7]),
//"=v"(p_b_thread[0]),
//"=v"(p_b_thread[1]),
//"=v"(p_b_thread[2]),
//"=v"(p_b_thread[3]),
//"=v"(p_b_thread[4]),
//"=v"(p_b_thread[5]),
//"=v"(p_b_thread[6]),
//"=v"(p_b_thread[7])
//:
//"v"(__to_local((void *)(&p_a_block[0]))),
//"v"(__to_local((void *)(&p_a_block[64]))),
//"v"(__to_local((void *)(&p_b_block[0]))),
//"v"(__to_local((void *)(&p_b_block[32])))
//);
//C = A * B
#else
auto
a_src_index
=
a_block_mtx
.
Get1dIndex
(
k_begin
,
0
)
+
mMyThreadOffsetA
;
auto
b_src_index
=
b_block_mtx
.
Get1dIndex
(
k_begin
,
0
)
+
mMyThreadOffsetB
;
...
...
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
View file @
114fdb58
...
...
@@ -323,7 +323,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
(
p_wei_block
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block
+
y
*
Wi
+
x
,
p_out_thread
,
f_accum
);
f_accum
,
Number
<
in_block_element_space
>
()
);
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment