Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2d5e45b8
Commit
2d5e45b8
authored
Mar 01, 2022
by
Shucai Xiao
Browse files
backup kernel refinement for add, mul, and mul_add
parent
16e5b5d0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
106 additions
and
113 deletions
+106
-113
src/targets/gpu/device/add.cpp
src/targets/gpu/device/add.cpp
+26
-9
src/targets/gpu/device/mul.cpp
src/targets/gpu/device/mul.cpp
+27
-20
src/targets/gpu/device/mul_add.cpp
src/targets/gpu/device/mul_add.cpp
+53
-84
No files found.
src/targets/gpu/device/add.cpp
View file @
2d5e45b8
...
...
@@ -8,27 +8,44 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
__global__
void
add_kernel
(
__half
*
a
,
__half
*
b
,
__half
*
r
,
int
n
)
// Detects the broadcast pattern this file's fast path targets — presumably
// the bias-add shape seen in BERT (hence the name); verify against callers.
// Returns true iff the inputs are rank-2 and the second input is broadcast
// along dimension 0 (stride 0), i.e. a single row reused for every row of
// the first input.
static bool is_bert(const std::vector<shape>& ss)
{
    const auto rank = ss.front().lens().size();
    if(rank != 2)
        return false;
    const auto& arg1_strides = ss.at(1).strides();
    return arg1_strides[0] == 0;
}
// Vectorized half-precision elementwise add with row broadcast.
//
// a and r point to n __half2 elements each; b points to n_dim __half2
// elements holding the single broadcast row (its index wraps modulo n_dim).
// n is the element count in __half2 units.
//
// NOTE(review): reinterpreting the buffers as __half2 assumes the underlying
// half buffers have an even element count and suitable alignment — the host
// launcher divides counts by 2, but alignment is not checked here; confirm.
__global__ void add_kernel(void* a, void* b, int n_dim, void* r, int n)
{
    __half2* ha = reinterpret_cast<__half2*>(a);
    __half2* hb = reinterpret_cast<__half2*>(b);
    __half2* hr = reinterpret_cast<__half2*>(r);

    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if(tid < n)
    {
        // b holds only one row of n_dim __half2 values; wrap to broadcast it.
        int idb = tid % n_dim;
        hr[tid] = __hadd2(ha[tid], hb[idb]);
    }
}
void
add
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
auto
s2
=
arg2
.
get_shape
();
if
(
s2
.
element_space
()
==
768
and
s2
.
type
()
==
shape
::
half_type
)
auto
sr
=
result
.
get_shape
();
std
::
vector
<
shape
>
ss
;
ss
.
push_back
(
arg1
.
get_shape
());
ss
.
push_back
(
arg2
.
get_shape
());
if
(
sr
.
type
()
==
shape
::
half_type
and
is_bert
(
ss
))
{
auto
elem_num
=
s2
.
elements
();
auto
elem_num
=
sr
.
elements
()
/
2
;
auto
last_dim
=
sr
.
lens
().
back
()
/
2
;
int
block_size
=
1024
;
int
block_num
=
(
elem_num
+
block_size
-
1
)
/
block_size
;
add_kernel
<<<
block_num
,
block_size
>>>
(
reinterpret_cast
<
__half
*>
(
arg1
.
data
()),
reinterpret_cast
<
__half
*>
(
arg2
.
data
()),
reinterpret_cast
<
__half
*>
(
result
.
data
()),
elem_num
);
add_kernel
<<<
block_num
,
block_size
>>>
(
arg1
.
data
(),
arg2
.
data
(),
last_dim
,
result
.
data
(),
elem_num
);
}
else
{
...
...
src/targets/gpu/device/mul.cpp
View file @
2d5e45b8
...
...
@@ -8,44 +8,51 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
__global__
void
mul_kernel
(
__half
*
a
,
__half
*
b
,
__half
*
r
,
int
n
)
// Predicate for the broadcast layout handled by the half2 fast path —
// presumably the BERT bias/scale pattern (name suggests it; confirm with
// callers). True iff the shapes are rank-2 and the second argument is
// broadcast along dimension 0 (its leading stride is 0).
static bool is_bert(const std::vector<shape>& ss)
{
    const auto rank = ss.front().lens().size();
    if(rank != 2)
        return false;
    const auto& arg1_strides = ss.at(1).strides();
    return arg1_strides[0] == 0;
}
// Vectorized half-precision elementwise multiply with row broadcast.
//
// a and r point to n __half2 elements each; b points to n_dim __half2
// elements holding the single broadcast row (index wraps modulo n_dim).
// n is the element count in __half2 units.
//
// NOTE(review): the __half2 reinterpretation assumes even half counts and
// adequate alignment of all three buffers — confirm at the launch sites.
__global__ void mul_kernel(void* a, void* b, int n_dim, void* r, int n)
{
    __half2* ha = reinterpret_cast<__half2*>(a);
    __half2* hb = reinterpret_cast<__half2*>(b);
    __half2* hr = reinterpret_cast<__half2*>(r);

    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if(tid < n)
    {
        // b holds only one broadcast row; wrap its index.
        int idb = tid % n_dim;
        hr[tid] = __hmul2(ha[tid], hb[idb]);
    }
}
// Elementwise multiply: result = arg1 * arg2.
//
// Takes a vectorized __half2 fast path when the output is half-typed and the
// shapes match the broadcast pattern recognized by is_bert(); otherwise falls
// back to the generic nary() implementation.
void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
    auto sr = result.get_shape();
    std::vector<shape> ss;
    ss.push_back(arg1.get_shape());
    ss.push_back(arg2.get_shape());

    if(sr.type() == shape::half_type and is_bert(ss))
    {
        // Each thread handles one __half2, i.e. two halfs; counts are halved.
        // NOTE(review): assumes sr.elements() and the last dimension are even
        // — is_bert() does not check this; confirm.
        auto elem_num  = sr.elements() / 2;
        auto last_dim  = sr.lens().back() / 2;
        int block_size = 1024;
        int block_num  = (elem_num + block_size - 1) / block_size;

        // Launch on the caller's stream (was the default stream, which would
        // break ordering with the rest of the program's work).
        mul_kernel<<<block_num, block_size, 0, stream>>>(
            arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
    }
    else
    {
        // Generic fallback. Must multiply — was `x + y`, an apparent
        // copy-paste from add.cpp that would compute the wrong result.
        nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x * y; });
    }
}
// Elementwise three-way multiply: result = arg1 * arg2 * arg3.
// Delegates entirely to the generic nary() device implementation.
void mul(hipStream_t stream,
         const argument& result,
         const argument& arg1,
         const argument& arg2,
         const argument& arg3)
{
    nary(stream, result, arg1, arg2, arg3)(
        [](auto a, auto b, auto c) __device__ { return a * b * c; });
}
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/device/mul_add.cpp
View file @
2d5e45b8
...
...
@@ -11,84 +11,51 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
//__global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int n)
//{
// int id = blockDim.x * blockIdx.x + threadIdx.x;
// __half* ha = reinterpret_cast<__half*>(a);
// __half* hb = reinterpret_cast<__half*>(b);
// __half* hx = reinterpret_cast<__half*>(x);
// __half* hr = reinterpret_cast<__half*>(r);
// if (id < n)
// {
// hr[id] = __float2half(__half2float(ha[id]) * __half2float(hx[id]) + __half2float(hb[id]));
// }
//}
// __global__ void mul_add_kernel(void* a, int an, void* x, int xn, void* b, int bn, void* r, int n)
// {
// int id = blockDim.x * blockIdx.x + threadIdx.x;
// __half2* ha = reinterpret_cast<__half2*>(a);
// __half2* hb = reinterpret_cast<__half2*>(b);
// __half2* hx = reinterpret_cast<__half2*>(x);
// __half2* hr = reinterpret_cast<__half2*>(r);
// if(id < n)
// {
// hr[id] = __hadd2(__hmul2(ha[id % an], hx[id % xn]), hb[id % bn]);
// }
// }
// Fused multiply-add (r = a * x + b) in __half2 units for the rank-2 broadcast
// case: x and b each hold a single row of last_dim __half2 values that is
// reused (broadcast) for every row of a.
//
// Parameter renamed from `dim3` — that name shadows HIP's built-in dim3 type
// inside the kernel, which is legal but an accident waiting to happen.
// n is the output element count in __half2 units.
//
// NOTE(review): assumes all four buffers are alignment-compatible with
// __half2 and that the half counts are even — confirm at the launch site.
__global__ void mul_add_kernel_dim3(void* a, void* x, void* b, int last_dim, void* r, int n)
{
    int id = blockDim.x * blockIdx.x + threadIdx.x;

    __half2* ha = reinterpret_cast<__half2*>(a);
    __half2* hb = reinterpret_cast<__half2*>(b);
    __half2* hx = reinterpret_cast<__half2*>(x);
    __half2* hr = reinterpret_cast<__half2*>(r);

    if(id < n)
    {
        // x and b are single rows; wrap their index to broadcast.
        auto id1 = id % last_dim;
        hr[id]   = __hadd2(__hmul2(ha[id], hx[id1]), hb[id1]);
    }
}
// Fused multiply-add (r = a * x + b) in __half2 units for the higher-rank
// broadcast case. a and x are indexed densely; only b is broadcast.
//
// n is the output element count in __half2 units; factor and dim4 come from
// the output lens at the launch site (factor = lens[1], dim4 = last dim / 2).
//
// NOTE(review): the broadcast index `id / factor + id % dim4` is taken as-is
// from the original; it presumably maps a dense output index onto b's
// stride-0 layout for this specific shape family, but that cannot be verified
// from this file alone — confirm against the shapes is_bert() accepts.
__global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int dim4, void* r, int n)
{
    int id = blockDim.x * blockIdx.x + threadIdx.x;

    __half2* ha = reinterpret_cast<__half2*>(a);
    __half2* hb = reinterpret_cast<__half2*>(b);
    __half2* hx = reinterpret_cast<__half2*>(x);
    __half2* hr = reinterpret_cast<__half2*>(r);

    if(id < n)
    {
        int idb = id / factor + id % dim4;
        hr[id]  = __hadd2(__hmul2(ha[id], hx[id]), hb[idb]);
    }
}
// void mul_add(hipStream_t stream,
// const argument& result,
// const argument& arg1,
// const argument& arg2,
// const argument& arg3)
// {
// auto type = result.get_shape().type();
// if(type == shape::half_type)
// {
// std::cout << "case1" << std::endl;
// mul_add_kernel<<<block_num, block_size>>>(
// arg1.data(), s1e, arg2.data(), s2e, arg3.data(), s3e, result.data(), elem_num);
// }
// else
// {
// std::cout << "mul_add" << std::endl;
// nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b)
// __device__ { return a * x + b; });
// }
// }
// Detects the broadcast layouts handled by the fused mul_add fast path —
// presumably BERT's scale/bias pattern (name suggests it; verify with
// callers). ss holds the shapes of the three inputs.
//
// Accepted layouts:
//  - rank 3: the third input (bias) is broadcast along dimension 1.
//  - rank 2: the second and third inputs share identical strides and are
//    broadcast along dimension 0.
static bool is_bert(const std::vector<shape>& ss)
{
    const auto rank = ss.front().lens().size();
    if(rank == 3)
    {
        const auto& bias_strides = ss.at(2).strides();
        return bias_strides[1] == 0;
    }
    if(rank == 2)
    {
        const auto& scale_strides = ss.at(1).strides();
        const auto& bias_strides  = ss.at(2).strides();
        return scale_strides == bias_strides and scale_strides[0] == 0;
    }
    return false;
}
void
mul_add
(
hipStream_t
stream
,
const
argument
&
result
,
...
...
@@ -97,27 +64,29 @@ void mul_add(hipStream_t stream,
const
argument
&
arg3
)
{
auto
sr
=
result
.
get_shape
();
auto
s1
=
arg1
.
get_shape
();
auto
s2
=
arg2
.
get_shape
();
auto
s3
=
arg3
.
get_shape
();
auto
type
=
sr
.
type
();
if
(
type
==
sr
.
type
())
std
::
vector
<
shape
>
ss
;
ss
.
push_back
(
arg1
.
get_shape
());
ss
.
push_back
(
arg2
.
get_shape
());
ss
.
push_back
(
arg3
.
get_shape
());
if
(
type
==
shape
::
half_type
and
is_bert
(
ss
))
{
hip_visit_all
(
result
,
arg1
,
arg2
,
arg3
,
sr
,
s1
,
s2
,
s3
)(
[
&
](
auto
r
,
auto
i1
,
auto
i2
,
auto
i3
,
auto
dsr
,
auto
ds1
,
auto
ds2
,
auto
ds3
)
{
__half2
*
rp
=
reinterpret_cast
<
__half2
*>
(
r
.
data
());
__half2
*
i1p
=
reinterpret_cast
<
__half2
*>
(
i1
.
data
());
__half2
*
i2p
=
reinterpret_cast
<
__half2
*>
(
i2
.
data
());
__half2
*
i3p
=
reinterpret_cast
<
__half2
*>
(
i3
.
data
());
gs_launch
(
stream
,
sr
.
elements
()
/
2
)([
=
](
auto
i
)
__device__
{
auto
idx
=
dsr
.
multi
(
i
);
auto
idx1
=
ds1
.
index
(
idx
);
auto
idx2
=
ds2
.
index
(
idx
);
auto
idx3
=
ds3
.
index
(
idx
);
rp
[
i
]
=
__hadd2
(
__hmul2
(
i1p
[
idx1
],
i2p
[
idx2
]),
i3p
[
idx3
]);
});
});
auto
elem_num
=
sr
.
elements
()
/
2
;
auto
lens
=
sr
.
lens
();
int
last_dim
=
lens
.
back
()
/
2
;
auto
n_dim
=
lens
.
size
();
int
block_size
=
1024
;
int
block_num
=
(
elem_num
+
block_size
-
1
)
/
block_size
;
if
(
n_dim
==
2
)
{
mul_add_kernel_dim3
<<<
block_num
,
block_size
>>>
(
arg1
.
data
(),
arg2
.
data
(),
arg3
.
data
(),
last_dim
,
result
.
data
(),
elem_num
);
}
else
{
int
factor
=
lens
[
1
];
mul_add_kernel_dim4
<<<
block_num
,
block_size
>>>
(
arg1
.
data
(),
arg2
.
data
(),
arg3
.
data
(),
factor
,
last_dim
,
result
.
data
(),
elem_num
);
}
}
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment