gaoqiong/MIGraphX commit 8a5f8759
Authored Mar 31, 2022 by Shucai Xiao
Parents: 86a03f28, 48b39e06

merge changes from the half2 branch

Showing 6 changed files, with 314 additions and 34 deletions:
src/targets/gpu/device/add.cpp (+7, -1)
src/targets/gpu/device/gelu.cpp (+7, -1)
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp (+6, -5)
src/targets/gpu/device/layernorm.cpp (+282, -26)
src/targets/gpu/device/mul.cpp (+5, -0)
src/targets/gpu/device/mul_add.cpp (+7, -1)
src/targets/gpu/device/add.cpp

```cpp
@@ -9,7 +9,13 @@ namespace gpu {
namespace device {

static bool is_bert(const std::vector<shape>& ss)
{
    auto last_dim = ss.front().lens().back();
    if(last_dim % 2 != 0)
    {
        return false;
    }
    auto n_dim = ss.front().lens().size();
    if(n_dim == 2)
    {
...
```
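The is_bert helper added here (and, in near-identical form, to gelu.cpp, mul.cpp, and mul_add.cpp below) is a shape heuristic gating the new half2 fast paths: the contiguous innermost dimension must be even so that each pair of __half values can be loaded as a single __half2, and the rank check matches the 2-D or 3-D activation shapes seen in BERT. A minimal standalone sketch of the same test (plain C++, not MIGraphX code; the function name is illustrative):

```cpp
// Standalone sketch of the shape test performed by the is_bert helpers.
// Not MIGraphX code: last_dim_allows_half2 is an illustrative name.
#include <cstddef>
#include <iostream>
#include <vector>

static bool last_dim_allows_half2(const std::vector<std::size_t>& lens)
{
    // Two __half values are read as one __half2, so the contiguous
    // innermost dimension must hold a whole number of pairs.
    return !lens.empty() and lens.back() % 2 == 0;
}

int main()
{
    std::vector<std::size_t> bert_like{8, 384, 768}; // batch, seq_len, hidden
    std::vector<std::size_t> odd_last{8, 384, 769};
    std::cout << last_dim_allows_half2(bert_like) << "\n"; // 1: 768 = 384 pairs
    std::cout << last_dim_allows_half2(odd_last) << "\n";  // 0: scalar fallback
}
```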
src/targets/gpu/device/gelu.cpp

```cpp
@@ -35,7 +35,13 @@ void gelu_new(hipStream_t stream, const argument& result, const argument& arg)
}

static bool is_bert(const std::vector<shape>& ss)
{
    auto last_dim = ss.front().lens().back();
    if(last_dim % 2 != 0)
    {
        return false;
    }
    auto n_dim = ss.front().lens().size();
    if(n_dim == 2)
    {
...
```
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp

```diff
@@ -176,11 +176,12 @@ template <index_int N, class T, class... Ts>
 auto hip_vec_visit_all(T&& x, Ts&&... xs)
 {
     return [&](auto f) {
-        hip_visit_all_impl(get_shape(x),
-                           make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }),
-                           f,
-                           x,
-                           xs...);
+        auto sx   = get_shape(x);
+        auto lens = sx.lens();
+        lens.back() /= N;
+        shape ssx{sx.type(), lens};
+        hip_visit_all_impl(ssx,
+                           make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }),
+                           f,
+                           x,
+                           xs...);
     };
 }
...
```
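This changes which index space hip_vec_visit_all hands to the visitor: the pointers are reinterpreted as vectors of N elements, so the shape the visitor iterates over now has its last dimension divided by N to match. A standalone sketch of the arithmetic (plain C++, not the MIGraphX shape class):

```cpp
// Sketch of the index-space adjustment in hip_vec_visit_all above:
// once T* is viewed as a vector of N elements, the innermost dimension
// shrinks by a factor of N.
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    const std::size_t N = 2;                    // e.g. __half viewed as __half2
    std::vector<std::size_t> lens{8, 384, 768};
    lens.back() /= N;                           // {8, 384, 384}
    std::size_t elements = 1;
    for(auto l : lens)
        elements *= l;
    std::cout << elements << "\n";              // 1179648 vector elements to visit
}
```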
src/targets/gpu/device/layernorm.cpp

```diff
@@ -2,6 +2,8 @@
 #include <migraphx/gpu/device/reduce.hpp>
 #include <migraphx/gpu/device/pow.hpp>
 #include <migraphx/gpu/device/fast_div.hpp>
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
```
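The two new includes supply the HIP runtime plus the fp16 types and intrinsics (__half, __half2, __hadd2, __hsub2, __hmul2, h2rsqrt, __float2half2_rn, and friends) used by the half2 kernels added further down in this file.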
```diff
@@ -79,24 +81,19 @@ __device__ auto auto_block_reduce(index idx, Op op, T init, index_int n, F f)
 }

 template <index_int MaxBlockSize, class Input, class Output>
-__device__ void layernorm(index_int i,
-                          index idx,
-                          std::size_t block_size_div,
-                          index_int relements,
-                          Input input,
-                          Output output)
+__device__ void layernorm(index idx, index_int relements, Input input, Output output)
 {
     using value_type       = decltype(input(idx.local));
     const auto relements_v = relements / vector_size<value_type>{};
-    const auto out_idx     = fast_div(i, block_size_div);
+    const auto out_idx     = blockIdx.x;
     const auto base_idx    = out_idx * relements_v;
     const auto input_idx   = base_idx + idx.local;
     const bool in_range    = idx.local < relements_v;

     auto mean = [&](auto z) {
-        auto m = auto_block_reduce<MaxBlockSize>(
-                     idx, sum{}, value_type(0), relements_v, [=](auto) { return z; }) /
-                 value_type(relements);
+        auto m = auto_block_reduce<MaxBlockSize>(
+            idx, sum{}, value_type(0), relements_v, [=](auto) { return z / value_type(relements); });
 #if MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC
         __builtin_amdgcn_s_barrier();
 #endif
...
```
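Two things change in the device-side layernorm helper: the output row index now comes straight from blockIdx.x instead of a fast_div of the global index (one block per row), and the mean lambda divides each term inside the reduction rather than dividing the finished sum. The two forms are algebraically identical, as below, but pre-dividing keeps the partial sums small, which matters when value_type is a half-precision type whose largest representable value is about 65504:

```math
\frac{1}{n}\sum_{i=1}^{n} z_i \;=\; \sum_{i=1}^{n} \frac{z_i}{n}
```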
```diff
@@ -131,14 +128,11 @@ void layernorm_vec_impl(hipStream_t stream,
     const auto relements_v           = relements / N;
     const std::size_t max_block_size = 256;
     const std::size_t block_size     = compute_block_size(relements_v, max_block_size);
-    const std::size_t block_size_div = encode_divisor(block_size);
     assert(relements_v <= block_size);
-    gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
-        layernorm<max_block_size>(
-            i,
-            idx,
-            block_size_div,
-            relements,
+    gs_launch(stream, nelements * block_size, block_size)([=](auto, auto idx) __device__ {
+        layernorm<max_block_size>(
+            idx,
+            relements,
             [&](auto input_idx) { return in(inputs.data()[input_idx]...); },
             [&](auto input_idx, auto x) {
...
```
```diff
@@ -160,14 +154,11 @@ void layernorm_impl(hipStream_t stream,
     hip_visit_all(result, args...)([&](auto output, auto... inputs) {
         const std::size_t max_block_size = 256;
         const std::size_t block_size     = compute_block_size(relements, max_block_size);
-        const std::size_t block_size_div = encode_divisor(block_size);
         assert(relements <= block_size);
-        gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
-            layernorm<max_block_size>(
-                i,
-                idx,
-                block_size_div,
-                relements,
+        gs_launch(stream, nelements * block_size, block_size)([=](auto, auto idx) __device__ {
+            layernorm<max_block_size>(
+                idx,
+                relements,
                 [&](auto input_idx) { return in(inputs.data()[input_idx]...); },
                 [&](auto input_idx, auto x) {
...
```
```diff
@@ -186,10 +177,6 @@ auto layernorm_fusion(hipStream_t stream,
     return [=](auto input, auto output) {
         auto relements = arg1.get_shape().lens().back();
         auto nelements = result.get_shape().elements() / relements;
-        // auto output_shape = result.get_shape();
-        // auto reduce_output_lens(output_shape.lens());
-        // reduce_output_lens.back() = 1;
         if((relements % 4) == 0)
             layernorm_vec_impl<4>(
                 stream, nelements, relements, input, output, result, arg1, args...);
...
```
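layernorm_fusion dispatches on the row length: when relements is divisible by 4 it takes the vectorized layernorm_vec_impl<4> path, so a BERT-style hidden size of 768 is processed as 768 / 4 = 192 four-wide vector elements per row; other row lengths continue (past the truncation above) to the scalar path.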
```diff
@@ -200,20 +187,289 @@
     };
 }

+struct half2_sum
+{
+    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const { return __hadd2(x, y); }
+};
+
+// in_data is in shared memory
+template <class Op>
+__device__ __half2 block_reduce_half2(
+    __half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
+{
+    __syncthreads();
+    for(index_int s = block_size; s > 0; s >>= 1)
+    {
+        if(tid < s and tid + s < batch_item_num)
+        {
+            buffer[tid] = op(buffer[tid], buffer[tid + s]);
+        }
+        __syncthreads();
+    }
+
+    auto lows2  = __low2half2(buffer[0]);
+    auto highs2 = __high2half2(buffer[0]);
+    return op(lows2, highs2);
+}
```
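block_reduce_half2 is a shared-memory tree reduction: each round folds element tid + s into element tid and halves the stride, with a barrier between rounds, and the guard tid + s < batch_item_num keeps it safe when the live element count is not a power of two. The final __low2half2/__high2half2 step combines the two 16-bit lanes of buffer[0], so the returned __half2 carries the full scalar sum in both lanes. A single-threaded host sketch of the same halving-stride pattern (plain float, illustrative only; no barriers needed on one thread):

```cpp
// Host-side sketch of the halving-stride reduction in block_reduce_half2.
// The guard mirrors the device code: buffer[tid + s] is folded into
// buffer[tid] only when the source index is inside the live data.
#include <cstddef>
#include <iostream>
#include <vector>

float tree_reduce(std::vector<float> buffer, std::size_t block_size)
{
    for(std::size_t s = block_size; s > 0; s >>= 1)
        for(std::size_t tid = 0; tid < s; ++tid)
            if(tid + s < buffer.size())
                buffer[tid] += buffer[tid + s];
    return buffer[0];
}

int main()
{
    std::vector<float> data(384, 1.0f); // e.g. a 768-half row seen as 384 lanes
    std::cout << tree_reduce(data, 256) << "\n"; // 384
}
```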
```diff
+
+// m = x - mean(x)
+// m / sqrt(mean(m ^ 2) + 1e-12)
+__global__ void triadd_layernorm_kernel_half2(
+    void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size)
+{
+    __half2* input1 = reinterpret_cast<__half2*>(in1);
+    __half2* input2 = reinterpret_cast<__half2*>(in2);
+    __half2* input3 = reinterpret_cast<__half2*>(in3);
+    __half2* output = reinterpret_cast<__half2*>(data_out);
+
+    auto rnum = __float2half2_rn(1.0f / batch_item_num);
+    batch_item_num /= 2;
+    extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
+    __half2* in_data_reduce = buffer2;
+    __half2* in_data        = buffer2 + batch_item_num;
+    int start               = blockIdx.x * batch_item_num;
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
+    {
+        int idx           = i + start;
+        in_data[i]        = __hadd2(__hadd2(input1[idx], input2[idx]), input3[idx]);
+        in_data_reduce[i] = in_data[i];
+        // in_data_reduce[i] = __hmul2(in_data[i], rnum);
+    }
+
+    auto m = block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
+    m      = __hmul2(m, rnum);
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
+    {
+        in_data[i] = __hsub2(in_data[i], m);
+        // in_data_reduce[i] = __hmul2(__hmul2(in_data[i], in_data[i]), rnum);
+        in_data_reduce[i] = __hmul2(in_data[i], in_data[i]);
+    }
+
+    m        = block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
+    m        = __hmul2(m, rnum);
+    auto eps = __float2half2_rn(1.0e-12f);
+    auto r   = __hadd2(m, eps);
+    r        = h2rsqrt(r);
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
+    {
+        int idx     = i + start;
+        output[idx] = __hmul2(in_data[i], r);
+    }
+}
```
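Each block normalizes one row: dynamic shared memory is split into two arrays of batch_item_num __half2 values, in_data holding the tri-added row and in_data_reduce serving as scratch that the reductions overwrite. Note that rnum is built from the row length in half elements before batch_item_num is halved to count __half2 pairs, so both means divide by the true element count. In the notation of the comment above the kernel, it computes:

```math
y = \frac{x - \mu}{\sqrt{\operatorname{mean}\!\left((x-\mu)^2\right) + 10^{-12}}},
\qquad \mu = \operatorname{mean}(x)
```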
```diff
+
+template <class T>
+__device__ T block_reduce_half(T* buffer, index_int batch_item_num, index_int tid, index_int block_size)
+{
+    __syncthreads();
+    for(index_int s = block_size; s > 0; s >>= 1)
+    {
+        if(tid < s and tid + s < batch_item_num)
+        {
+            buffer[tid] = __float2half(__half2float(buffer[tid]) + __half2float(buffer[tid + s]));
+        }
+        __syncthreads();
+    }
+
+    return buffer[0];
+}
```
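Unlike the __half2 variant, block_reduce_half widens both operands to float with __half2float before each addition and narrows the result back with __float2half, avoiding repeated accumulation in 16-bit arithmetic (and any dependence on a native __half operator+).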
```diff
+
+// m = x - mean(x)
+// m / sqrt(mean(m ^ 2) + 1e-12)
+__global__ void triadd_layernorm_kernel_half(
+    void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size)
+{
+    __half* input1 = reinterpret_cast<__half*>(in1);
+    __half* input2 = reinterpret_cast<__half*>(in2);
+    __half* input3 = reinterpret_cast<__half*>(in3);
+    __half* output = reinterpret_cast<__half*>(data_out);
+    extern MIGRAPHX_DEVICE_SHARED __half bufferh[];
+    __half* in_data_reduce = bufferh;
+    __half* in_data        = bufferh + batch_item_num;
+    int start              = blockIdx.x * batch_item_num;
+    auto rnum              = 1.0f / batch_item_num;
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
+    {
+        int idx    = i + start;
+        in_data[i] = __float2half(__half2float(input1[idx]) + __half2float(input2[idx]) +
+                                  __half2float(input3[idx]));
+        in_data_reduce[i] = __float2half(__half2float(in_data[i]) * __half2float(rnum));
+    }
+
+    auto m = block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size);
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
+    {
+        in_data[i]        = __float2half(__half2float(in_data[i]) - __half2float(m));
+        in_data_reduce[i] = __float2half(__half2float(in_data[i]) * __half2float(in_data[i]) *
+                                         __half2float(rnum));
+    }
+
+    m = __float2half(
+        __half2float(block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size)) +
+        1.0e-12f);
+    auto r = __float2half(rsqrt(__half2float(m)));
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
+    {
+        int idx     = i + start;
+        output[idx] = __float2half(__half2float(in_data[i]) * __half2float(r));
+    }
+}
```
```diff
+
+template <class T>
+__device__ T block_reduce(T* buffer, index_int batch_item_num, index_int tid, index_int block_size)
+{
+    __syncthreads();
+    for(index_int s = block_size; s > 0; s >>= 1)
+    {
+        if(tid < s and tid + s < batch_item_num)
+        {
+            buffer[tid] = buffer[tid] + buffer[tid + s];
+        }
+        __syncthreads();
+    }
+
+    return buffer[0];
+}
```
```diff
+
+// m = x - mean(x)
+// m / sqrt(mean(m ^ 2) + 1e-12)
+template <class T>
+__global__ void triadd_layernorm_kernel(
+    void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size)
+{
+    T* input1 = reinterpret_cast<T*>(in1);
+    T* input2 = reinterpret_cast<T*>(in2);
+    T* input3 = reinterpret_cast<T*>(in3);
+    T* output = reinterpret_cast<T*>(data_out);
+    extern MIGRAPHX_DEVICE_SHARED T buffer[];
+    T* in_data_reduce = buffer;
+    T* in_data        = buffer + batch_item_num;
+    int start         = blockIdx.x * batch_item_num;
+    auto rnum         = 1.0f / batch_item_num;
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
+    {
+        int idx           = i + start;
+        in_data[i]        = input1[idx] + input2[idx] + input3[idx];
+        in_data_reduce[i] = in_data[i];
+        // in_data_reduce[i] = __half2float(in_data[i]) * rnum;
+    }
+
+    auto m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size);
+    m      = m * rnum;
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
+    {
+        in_data[i]        = in_data[i] - m;
+        in_data_reduce[i] = in_data[i] * in_data[i];
+        // in_data_reduce[i] = __half2float(in_data[i] * in_data[i]) * rnum;
+    }
+
+    m      = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size);
+    m      = m * rnum + 1.0e-12f;
+    auto r = rsqrt(m);
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
+    {
+        int idx = i + start;
+        // output[idx] = __half2float(in_data[i]) * r;
+        output[idx] = in_data[i] * r;
+    }
+}
```
```diff
 void triadd_layernorm(hipStream_t stream,
                       const argument& result,
                       const argument& arg1,
                       const argument& arg2,
                       const argument& arg3)
 {
-    layernorm_fusion(stream, result, arg1, arg2, arg3)(
-        [](auto x, auto y, auto z) { return x + y + z; }, [](auto x, auto& y, auto...) { y = x; });
+    auto in_s           = arg1.get_shape();
+    auto type           = in_s.type();
+    auto batch_item_num = in_s.lens().back();
+    if(type == shape::half_type and (batch_item_num % 2) == 0)
+    {
+        auto half2_block_size = compute_block_size(batch_item_num, 1024);
+        int block_num         = in_s.elements() / batch_item_num;
+        int shared_size       = batch_item_num * 2 * in_s.type_size();
+        half2_block_size      = half2_block_size / 4;
+        triadd_layernorm_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>(
+            arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num, half2_block_size);
+    }
+    else
+    {
+        layernorm_fusion(stream, result, arg1, arg2, arg3)(
+            [](auto x, auto y, auto z) { return x + y + z; }, [](auto x, auto& y, auto...) { y = x; });
+    }
+}
```
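The half2 path launches one block per row. Below is a hypothetical walk-through of the launch arithmetic for a BERT-base-like half-precision input of shape {8, 384, 768}; the value returned by compute_block_size is an assumption (taken here as the row length itself, capped at 1024), everything else follows from the code above:

```cpp
// Hypothetical launch arithmetic for triadd_layernorm on a {8, 384, 768}
// half input. compute_block_size's result is assumed (see lead-in).
#include <cstddef>
#include <iostream>

int main()
{
    std::size_t batch_item_num = 768;           // in_s.lens().back()
    std::size_t elements       = 8 * 384 * 768; // in_s.elements()
    std::size_t type_size      = 2;             // sizeof(__half)

    std::size_t block_size  = 768;                            // assumed compute_block_size(768, 1024)
    std::size_t block_num   = elements / batch_item_num;      // 3072: one block per row
    std::size_t shared_size = batch_item_num * 2 * type_size; // 3072 B: in_data + reduce scratch
    block_size /= 4;                                          // 192 threads per block

    // Inside the kernel the row is 768 / 2 = 384 __half2 values, so each of
    // the 192 threads handles two __half2 (four half values) per loop.
    std::cout << block_num << " blocks, " << block_size << " threads, "
              << shared_size << " bytes of shared memory\n";
}
```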
```diff
+
+__global__ void
+layernorm_kernel_half2(void* in1, void* data_out, index_int batch_item_num, index_int block_size)
+{
+    __half2* input1 = reinterpret_cast<__half2*>(in1);
+    __half2* output = reinterpret_cast<__half2*>(data_out);
+    auto rnum       = __float2half2_rn(1.0f / batch_item_num);
+    batch_item_num /= 2;
+    extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
+    __half2* in_data_reduce = buffer2;
+    __half2* in_data        = buffer2 + batch_item_num;
+    int start               = blockIdx.x * batch_item_num;
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
+    {
+        int idx           = i + start;
+        in_data[i]        = input1[idx];
+        in_data_reduce[i] = in_data[i];
+    }
+
+    auto m = block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
+    m      = __hmul2(m, rnum);
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
+    {
+        in_data[i]        = __hsub2(in_data[i], m);
+        in_data_reduce[i] = __hmul2(in_data[i], in_data[i]);
+    }
+
+    m        = block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
+    m        = __hmul2(m, rnum);
+    auto eps = __float2half2_rn(1.0e-12f);
+    auto r   = __hadd2(m, eps);
+    r        = h2rsqrt(r);
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
+    {
+        int idx     = i + start;
+        output[idx] = __hmul2(in_data[i], r);
+    }
+}
```
```diff
 void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
 {
-    layernorm_fusion(stream, result, arg1)([](auto x) { return x; },
-                                           [](auto x, auto& y, auto) { y = x; });
+    auto in_s           = arg1.get_shape();
+    auto type           = in_s.type();
+    auto batch_item_num = in_s.lens().back();
+    if(type == shape::half_type and (batch_item_num % 2) == 0)
+    {
+        auto half2_block_size = compute_block_size(batch_item_num, 1024);
+        int block_num         = in_s.elements() / batch_item_num;
+        int shared_size       = batch_item_num * 2 * in_s.type_size();
+        half2_block_size      = half2_block_size / 4;
+        layernorm_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>(
+            arg1.data(), result.data(), batch_item_num, half2_block_size);
+    }
+    else
+    {
+        layernorm_fusion(stream, result, arg1)([](auto x) { return x; },
+                                               [](auto x, auto& y, auto) { y = x; });
+    }
+}

 } // namespace device
...
```
src/targets/gpu/device/mul.cpp

```cpp
@@ -10,6 +10,11 @@ namespace device {
static bool is_bert(const std::vector<shape>& ss)
{
    auto last_dim = ss.front().lens().back();
    if(last_dim % 2 != 0)
    {
        return false;
    }
    auto n_dim = ss.front().lens().size();
    if(n_dim == 2)
    {
...
```
src/targets/gpu/device/mul_add.cpp

```cpp
@@ -40,7 +40,13 @@ __global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int d
}

static bool is_bert(const std::vector<shape>& ss)
{
    auto last_dim = ss.front().lens().back();
    if(last_dim % 2 != 0)
    {
        return false;
    }
    auto n_dim = ss.front().lens().size();
    if(n_dim == 3)
    {
...
```