gaoqiong / MIGraphX · Commits

Commit a5181cd0, authored Mar 28, 2022 by Shucai Xiao

layernorm kernel optimization

parent d5c2538c
Showing 1 changed file with 224 additions and 6 deletions

src/targets/gpu/device/layernorm.cpp  (+224 −6)
@@ -2,6 +2,8 @@
 #include <migraphx/gpu/device/reduce.hpp>
 #include <migraphx/gpu/device/pow.hpp>
 #include <migraphx/gpu/device/fast_div.hpp>
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -94,9 +96,9 @@ __device__ void layernorm(index_int i,
     const bool in_range = idx.local < relements_v;

     auto mean = [&](auto z) {
-        auto m = auto_block_reduce<MaxBlockSize>(
-                     idx, sum{}, value_type(0), relements_v, [=](auto) { return z; }) /
-                 value_type(relements);
+        auto m = auto_block_reduce<MaxBlockSize>(
+            idx, sum{}, value_type(0), relements_v, [=](auto) { return z / value_type(relements); });

 #if MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC
         __builtin_amdgcn_s_barrier();
 #endif
@@ -158,7 +160,7 @@ void layernorm_impl(hipStream_t stream,
                     const Arguments&... args)
 {
     hip_visit_all(result, args...)([&](auto output, auto... inputs) {
-        const std::size_t max_block_size = 256;
+        const std::size_t max_block_size = 128;
         const std::size_t block_size     = compute_block_size(relements, max_block_size);
         const std::size_t block_size_div = encode_divisor(block_size);
         assert(relements <= block_size);
@@ -200,14 +202,230 @@ auto layernorm_fusion(hipStream_t stream,
    };
}

struct half2_sum
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const
    {
        return __hadd2(x, y);
    }
};
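// half2_sum wraps the packed fp16 intrinsic __hadd2 (lane-wise addition of the
// two halves of a __half2) in a functor so it can be passed as the reduction
// op to block_reduce_half2 below.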
// in_data is in shared memory
template <class Op>
__device__ __half2 block_reduce_half2(
    __half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
    __syncthreads();
    for(index_int s = block_size; s > 0; s >>= 1)
    {
        if(tid < s and tid + s < batch_item_num)
        {
            buffer[tid] = op(buffer[tid], buffer[tid + s]);
        }
        __syncthreads();
    }

    auto lows2  = __low2half2(buffer[0]);
    auto highs2 = __high2half2(buffer[0]);

    return op(lows2, highs2);
}
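// Reduction shape (illustrative): with block_size = 4 and batch_item_num = 8,
// the folds are s = 4: b[0]+=b[4] ... b[3]+=b[7]; s = 2: b[0]+=b[2], b[1]+=b[3];
// s = 1: b[0]+=b[1]. Starting at s == block_size (not block_size / 2) lets the
// buffer hold up to 2 * block_size entries. __low2half2/__high2half2 then
// broadcast the two packed lanes of b[0] so the final op() folds the lanes
// together, leaving the full reduction result in both lanes of the return value.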
// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)
__global__ void triadd_layernorm_kernel_half2(void* in1,
                                              void* in2,
                                              void* in3,
                                              void* data_out,
                                              index_int batch_item_num,
                                              index_int block_size)
{
    __half2* input1 = reinterpret_cast<__half2*>(in1);
    __half2* input2 = reinterpret_cast<__half2*>(in2);
    __half2* input3 = reinterpret_cast<__half2*>(in3);
    __half2* output = reinterpret_cast<__half2*>(data_out);
    batch_item_num /= 2;

    extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
    __half2* in_data_reduce = buffer2;
    __half2* in_data        = buffer2 + batch_item_num;
    int start               = blockIdx.x * batch_item_num;
    auto rnum               = __float2half2_rn(1.0f / batch_item_num);
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        int idx           = i + start;
        in_data[i]        = __hadd2(__hadd2(input1[idx], input2[idx]), input3[idx]);
        in_data_reduce[i] = __hmul2(in_data[i], rnum);
    }

    auto m =
        block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});

    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i]        = __hsub2(in_data[i], m);
        in_data_reduce[i] = __hmul2(__hmul2(in_data[i], in_data[i]), rnum);
    }

    m = block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});

    auto eps = __float2half2_rn(1.0e-12f);
    auto r   = __hadd2(m, eps);
    r        = h2rsqrt(r);
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        int idx     = i + start;
        output[idx] = __hmul2(in_data[i], r);
    }
}
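// Intended dataflow for one row x of length n (per the comments above):
//   m   = sum_i(x_i / n)          -- first reduction, over pre-scaled values
//   d_i = x_i - m
//   v   = sum_i(d_i * d_i / n)    -- second reduction
//   y_i = d_i * rsqrt(v + 1e-12)
// Here x is the element-wise sum of the three inputs, and each __half2 slot
// carries two adjacent row elements; note that after batch_item_num /= 2,
// rnum is derived from the __half2 count rather than the element count.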
template <class T>
__device__ T
block_reduce_half(T* buffer, index_int batch_item_num, index_int tid, index_int block_size)
{
    __syncthreads();
    for(index_int s = block_size; s > 0; s >>= 1)
    {
        if(tid < s and tid + s < batch_item_num)
        {
            buffer[tid] =
                __float2half(__half2float(buffer[tid]) + __half2float(buffer[tid + s]));
        }
        __syncthreads();
    }

    return buffer[0];
}
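// Each fold above round-trips through float (__half2float -> add -> __float2half),
// so every partial sum is formed in fp32 and only the stored value is rounded
// back to fp16, instead of accumulating rounding error in half precision.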
// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)
__global__ void triadd_layernorm_kernel_half(void* in1,
                                             void* in2,
                                             void* in3,
                                             void* data_out,
                                             index_int batch_item_num,
                                             index_int block_size)
{
    __half* input1 = reinterpret_cast<__half*>(in1);
    __half* input2 = reinterpret_cast<__half*>(in2);
    __half* input3 = reinterpret_cast<__half*>(in3);
    __half* output = reinterpret_cast<__half*>(data_out);

    extern MIGRAPHX_DEVICE_SHARED __half bufferh[];
    __half* in_data_reduce = bufferh;
    __half* in_data        = bufferh + batch_item_num;
    int start              = blockIdx.x * batch_item_num;
    auto rnum              = 1.0f / batch_item_num;
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        int idx    = i + start;
        in_data[i] = __float2half(__half2float(input1[idx]) + __half2float(input2[idx]) +
                                  __half2float(input3[idx]));
        in_data_reduce[i] = __float2half(__half2float(in_data[i]) * __half2float(rnum));
    }

    auto m = block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size);

    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i] = __float2half(__half2float(in_data[i]) - __half2float(m));
        in_data_reduce[i] =
            __float2half(__half2float(in_data[i]) * __half2float(in_data[i]) * __half2float(rnum));
    }

    m = __float2half(
        __half2float(block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size)) +
        1.0e-12f);
    auto r = __float2half(rsqrt(__half2float(m)));
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        int idx     = i + start;
        output[idx] = __float2half(__half2float(in_data[i]) * __half2float(r));
    }
}
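// This scalar __half variant mirrors the __half2 kernel above but keeps one
// element per shared-memory slot and routes all arithmetic through fp32
// round-trips; in this change it is only referenced from the commented-out
// launch further below.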
template <class T>
__device__ T block_reduce(T* buffer, index_int batch_item_num, index_int tid, index_int block_size)
{
    __syncthreads();
    for(index_int s = block_size; s > 0; s >>= 1)
    {
        if(tid < s and tid + s < batch_item_num)
        {
            buffer[tid] = buffer[tid] + buffer[tid + s];
        }
        __syncthreads();
    }

    return buffer[0];
}
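// Generic fallback reduction: the same tree fold as the half-specific
// versions, but relying on T's native operator+ instead of fp16 intrinsics.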
// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)
template <class T>
__global__ void triadd_layernorm_kernel(void* in1,
                                        void* in2,
                                        void* in3,
                                        void* data_out,
                                        index_int batch_item_num,
                                        index_int block_size)
{
    T* input1 = reinterpret_cast<T*>(in1);
    T* input2 = reinterpret_cast<T*>(in2);
    T* input3 = reinterpret_cast<T*>(in3);
    T* output = reinterpret_cast<T*>(data_out);

    extern MIGRAPHX_DEVICE_SHARED T buffer[];
    T* in_data_reduce = buffer;
    T* in_data        = buffer + batch_item_num;
    int start         = blockIdx.x * batch_item_num;
    auto rnum         = 1.0f / batch_item_num;
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        int idx           = i + start;
        in_data[i]        = input1[idx] + input2[idx] + input3[idx];
        in_data_reduce[i] = in_data[i] * rnum;
    }

    auto m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size);

    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i]        = in_data[i] - m;
        in_data_reduce[i] = in_data[i] * in_data[i] * rnum;
    }

    m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size) + 1.0e-12f;

    auto r = rsqrt(m);
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        int idx     = i + start;
        output[idx] = in_data[i] * r;
    }
}
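// Of the three triadd kernels above, only triadd_layernorm_kernel_half2 is
// actually launched in this change (see triadd_layernorm below); the scalar
// variants appear to be kept for experimentation.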
void triadd_layernorm(hipStream_t stream,
                      const argument& result,
                      const argument& arg1,
                      const argument& arg2,
                      const argument& arg3)
{
    layernorm_fusion(stream, result, arg1, arg2, arg3)(
        [](auto x, auto y, auto z) { return x + y + z; },
        [](auto x, auto& y, auto...) { y = x; });

    auto in_s           = arg1.get_shape();
    auto type           = in_s.type();
    auto batch_item_num = in_s.lens().back();
    if(type == shape::half_type and (batch_item_num % 2) == 0)
    {
        auto half2_block_size = compute_block_size(batch_item_num, 1024);
        int block_num         = in_s.elements() / batch_item_num;
        int shared_size       = batch_item_num * 2 * in_s.type_size();
        half2_block_size      = half2_block_size / 4;
        triadd_layernorm_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>(
            arg1.data(),
            arg2.data(),
            arg3.data(),
            result.data(),
            batch_item_num,
            half2_block_size);
    }
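    // The scalar __half launch below was evidently tried and left disabled: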
    // if(type == shape::half_type and (batch_item_num % 2) == 0)
    // {
    //     auto reduce_block_size = compute_block_size(batch_item_num, 1024);
    //     int block_num = in_s.elements() / batch_item_num;
    //     int shared_size = batch_item_num * 2 * in_s.type_size();
    //     reduce_block_size = reduce_block_size / 2;
    //     triadd_layernorm_kernel_half<<<block_num, reduce_block_size, shared_size, stream>>>(
    //         arg1.data(),
    //         arg2.data(),
    //         arg3.data(),
    //         result.data(),
    //         batch_item_num,
    //         reduce_block_size);
    // }
    else
    {
        layernorm_fusion(stream, result, arg1, arg2, arg3)(
            [](auto x, auto y, auto z) { return x + y + z; },
            [](auto x, auto& y, auto...) { y = x; });
    }
}
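// Dispatch summary: half-precision rows of even length take the packed-__half2
// kernel; all other shapes fall back to the generic layernorm_fusion path.
// (The unconditional layernorm_fusion call at the top of triadd_layernorm runs
// in both cases, so on the half2 path the result is computed twice; this reads
// like benchmarking scaffolding rather than final dispatch logic.)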
void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
...
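For reference, here is the launch arithmetic of the half2 path in triadd_layernorm worked through on a hypothetical input shape. This is an illustrative standalone sketch, not part of the commit: the shape {8, 384, 768} is invented, and the value assumed for compute_block_size(768, 1024) is a guess at its rounding behavior.

#include <cstddef>
#include <cstdio>

int main()
{
    // Hypothetical half_type input of shape {8, 384, 768}:
    std::size_t elements       = 8 * 384 * 768; // in_s.elements()
    std::size_t batch_item_num = 768;           // in_s.lens().back(), even
    std::size_t type_size      = 2;             // in_s.type_size() for fp16

    // One block per row:
    int block_num = elements / batch_item_num; // 3072

    // Two shared buffers (in_data_reduce and in_data), each holding one row:
    int shared_size = batch_item_num * 2 * type_size; // 3072 bytes

    // Assuming compute_block_size(768, 1024) returns 1024, the kernel runs
    // with a quarter of that, and the 256 threads stride over the
    // batch_item_num / 2 = 384 __half2 elements of each row:
    int half2_block_size = 1024 / 4; // 256

    std::printf("grid=%d block=%d shared=%dB\n", block_num, half2_block_size, shared_size);
    return 0;
}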