Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
dfe423ae
Commit
dfe423ae
authored
Mar 31, 2022
by
BoxiangW
Committed by
binmakeswell
Apr 06, 2022
Browse files
fix format (#572)
parent
cfb41297
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
73 additions
and
63 deletions
+73
-63
colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu
...alai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu
+73
-63
No files found.
colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu
View file @
dfe423ae
...
@@ -7,8 +7,7 @@ namespace cg = cooperative_groups;
...
@@ -7,8 +7,7 @@ namespace cg = cooperative_groups;
const
float
LN_EPSILON
=
1e-8
f
;
const
float
LN_EPSILON
=
1e-8
f
;
#define TILE_DIM 32
#define TILE_DIM 32
template
<
typename
T
>
template
<
typename
T
>
__forceinline__
__device__
T
add_eps
(
T
x
)
{
__forceinline__
__device__
T
add_eps
(
T
x
)
{
return
fabsf
(
x
)
>
LN_EPSILON
?
x
:
(
x
<
0
?
-
LN_EPSILON
:
LN_EPSILON
);
return
fabsf
(
x
)
>
LN_EPSILON
?
x
:
(
x
<
0
?
-
LN_EPSILON
:
LN_EPSILON
);
}
}
...
@@ -138,13 +137,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
...
@@ -138,13 +137,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars,
// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars,
// __half *means, const __half *inp,
// __half *means, const __half *inp,
// const __half *scale, const __half
*bias,
// const __half *scale, const __half
// int hidden_size) {
//
*bias,
int hidden_size) {
// // step 0. compute local sum
// // step 0. compute local sum
// float l_sum = 0;
// float l_sum = 0;
// float l_square_sum = 0;
// float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) {
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x *
// 2) {
// float4 val_f4 = inp_f4[idx];
// float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1];
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// __half2 *val_h2 = (__half2 *)(&val_f4);
...
@@ -154,7 +154,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
...
@@ -154,7 +154,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y;
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y;
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y;
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x
// * val_f2_1.x + val_f2_1.y * val_f2_1.y;
// }
// }
// }
// }
...
@@ -176,7 +177,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
...
@@ -176,7 +177,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// // step 2. layer norm result
// // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2;
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) {
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x *
// 2) {
// // load scale, bias, input
// // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
...
@@ -202,9 +204,9 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
...
@@ -202,9 +204,9 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_h2[i] = __float22half2_rn(val_f2);
// val_h2[i] = __float22half2_rn(val_f2);
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
bias_f2_1.x;
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
+ bias_f2_1.y;
//
bias_f2_1.x;
val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
// val_h2_1[i] = __float22half2_rn(val_f2_1);
//
+ bias_f2_1.y;
val_h2_1[i] = __float22half2_rn(val_f2_1);
// }
// }
// output_f4[idx] = val_f4;
// output_f4[idx] = val_f4;
// output_f4[idx+1] = val_f4_1;
// output_f4[idx+1] = val_f4_1;
...
@@ -213,13 +215,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
...
@@ -213,13 +215,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars,
// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars,
// __half *means, const __half *inp,
// __half *means, const __half *inp,
// const __half *scale, const __half
*bias,
// const __half *scale, const __half
// int hidden_size) {
//
*bias,
int hidden_size) {
// // step 0. compute local sum
// // step 0. compute local sum
// float l_sum = 0;
// float l_sum = 0;
// float l_square_sum = 0;
// float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) {
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x *
// 4) {
// float4 val_f4 = inp_f4[idx];
// float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1];
// float4 val_f4_1 = inp_f4[idx+1];
// float4 val_f4_2 = inp_f4[idx+2];
// float4 val_f4_2 = inp_f4[idx+2];
...
@@ -234,11 +237,12 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
...
@@ -234,11 +237,12 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// float2 val_f2_2 = __half22float2(val_h2_2[i]);
// float2 val_f2_2 = __half22float2(val_h2_2[i]);
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + val_f2_2.y + val_f2_3.x + val_f2_3.y;
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x +
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y;
// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x *
// l_square_sum += val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y;
// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x
// l_square_sum += val_f2_2.x * val_f2_2.x + val_f2_2.y * val_f2_2.y;
// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x +
// l_square_sum += val_f2_3.x * val_f2_3.x + val_f2_3.y * val_f2_3.y;
// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x +
// val_f2_3.y * val_f2_3.y;
// }
// }
// }
// }
...
@@ -260,7 +264,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
...
@@ -260,7 +264,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// // step 2. layer norm result
// // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4;
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) {
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x *
// 4) {
// // load scale, bias, input
// // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
...
@@ -303,14 +308,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
...
@@ -303,14 +308,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
bias_f2_1.x;
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
+ bias_f2_1.y;
//
bias_f2_1.x;
val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
// val_f2_2.x = (val_f2_2.x - s_mean) * s_var *
scale_f2_2.x + bias_f2_2.x;
//
+ bias_f2_1.y;
val_f2_2.x = (val_f2_2.x - s_mean) * s_var *
// val_f2_2.y = (val_f2_2.y - s_mean) * s_var
* scale_f2_2.y + bias_f2_2.y;
//
scale_f2_2.x + bias_f2_2.x;
val_f2_2.y = (val_f2_2.y - s_mean) * s_var
// val_f2_3.x = (val_f2_3.x - s_mean) *
s_var * scale_f2_3.x + bias_f2_3.x;
//
* scale_f2_2.y + bias_f2_2.y;
val_f2_3.x = (val_f2_3.x - s_mean) *
//
val_f2_3.y = (val_f2_3.y - s_mean) * s_var * scale
_f2_3.y
+ bias
_f2_3.y
;
//
s_var * scale_f2_3.x + bias_f2_3.x; val
_f2_3.y
= (val
_f2_3.y
- s_mean)
//
val_h2[i] = __float22half2_rn(val_f2);
//
* s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] =
// val_h2_1[i] = __float22half2_rn(val_f2_1);
//
__float22half2_rn(val_f2);
val_h2_1[i] = __float22half2_rn(val_f2_1);
// val_h2_2[i] = __float22half2_rn(val_f2_2);
// val_h2_2[i] = __float22half2_rn(val_f2_2);
// val_h2_3[i] = __float22half2_rn(val_f2_3);
// val_h2_3[i] = __float22half2_rn(val_f2_3);
// }
// }
...
@@ -414,11 +419,10 @@ means: [batch_size * seq_len], mean of ln forward,
...
@@ -414,11 +419,10 @@ means: [batch_size * seq_len], mean of ln forward,
(gamma && betta) ^ (vars && means) should be true
(gamma && betta) ^ (vars && means) should be true
*/
*/
template
<
typename
T
>
template
<
typename
T
>
__global__
void
ker_ln_bw_dgamma_dbetta
(
T
*
gamma_grad
,
T
*
betta_grad
,
__global__
void
const
T
*
out_grad
,
const
T
*
inp_or_out
,
ker_ln_bw_dgamma_dbetta
(
T
*
gamma_grad
,
T
*
betta_grad
,
const
T
*
out_grad
,
const
T
*
gamma
,
const
T
*
betta
,
const
T
*
inp_or_out
,
const
T
*
gamma
,
const
T
*
betta
,
const
T
*
vars
,
const
T
*
means
,
int
rows
,
const
T
*
vars
,
const
T
*
means
,
int
rows
,
int
width
)
{
int
width
)
{
__shared__
float
betta_buffer
[
TILE_DIM
][
TILE_DIM
];
__shared__
float
betta_buffer
[
TILE_DIM
][
TILE_DIM
];
__shared__
float
gamma_buffer
[
TILE_DIM
][
TILE_DIM
];
__shared__
float
gamma_buffer
[
TILE_DIM
][
TILE_DIM
];
...
@@ -698,11 +702,10 @@ __global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad,
...
@@ -698,11 +702,10 @@ __global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad,
}
}
__global__
void
ker_ln_bw_dinp_x2
(
__half
*
inp_grad
,
const
__half
*
out_grad
,
__global__
void
ker_ln_bw_dinp_x2
(
__half
*
inp_grad
,
const
__half
*
out_grad
,
const
__half
*
residual_grad
,
const
__half
*
residual_grad
,
const
__half
*
inp_or_out
,
const
__half
*
inp_or_out
,
const
__half
*
gamma
,
const
__half
*
gamma
,
const
__half
*
betta
,
const
__half
*
betta
,
const
__half
*
vars
,
const
__half
*
vars
,
const
__half
*
means
,
const
__half
*
means
,
int
hidden_dim
)
{
int
hidden_dim
)
{
int
offset
=
blockIdx
.
x
*
hidden_dim
*
2
+
threadIdx
.
x
*
2
;
int
offset
=
blockIdx
.
x
*
hidden_dim
*
2
+
threadIdx
.
x
*
2
;
float2
dxhat
[
4
],
xhat
[
4
];
float2
dxhat
[
4
],
xhat
[
4
];
...
@@ -762,7 +765,8 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
...
@@ -762,7 +765,8 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
xhat
[
i
].
y
=
(
vout
.
y
-
vbetta
.
y
)
/
add_eps
(
vgamma
.
y
);
xhat
[
i
].
y
=
(
vout
.
y
-
vbetta
.
y
)
/
add_eps
(
vgamma
.
y
);
xhat_1
[
i
].
y
=
(
vout_1
.
y
-
vbetta_1
.
y
)
/
add_eps
(
vgamma_1
.
y
);
xhat_1
[
i
].
y
=
(
vout_1
.
y
-
vbetta_1
.
y
)
/
add_eps
(
vgamma_1
.
y
);
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_1
[
i
].
x
*
dxhat_1
[
i
].
x
+
xhat_1
[
i
].
y
*
dxhat_1
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_1
[
i
].
x
*
dxhat_1
[
i
].
x
+
xhat_1
[
i
].
y
*
dxhat_1
[
i
].
y
;
}
}
}
else
{
}
else
{
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
...
@@ -776,7 +780,8 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
...
@@ -776,7 +780,8 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
xhat
[
i
].
y
=
(
vinp
.
y
-
fmean
)
*
var_rsqrt
;
xhat
[
i
].
y
=
(
vinp
.
y
-
fmean
)
*
var_rsqrt
;
xhat_1
[
i
].
y
=
(
vinp_1
.
y
-
fmean
)
*
var_rsqrt
;
xhat_1
[
i
].
y
=
(
vinp_1
.
y
-
fmean
)
*
var_rsqrt
;
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_1
[
i
].
x
*
dxhat_1
[
i
].
x
+
xhat_1
[
i
].
y
*
dxhat_1
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_1
[
i
].
x
*
dxhat_1
[
i
].
x
+
xhat_1
[
i
].
y
*
dxhat_1
[
i
].
y
;
}
}
}
}
}
}
...
@@ -802,7 +807,7 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
...
@@ -802,7 +807,7 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
// Add the residual grad,
// Add the residual grad,
// usually in pre-layer-norm for transformer layer
// usually in pre-layer-norm for transformer layer
float4
dresidual
=
((
const
float4
*
)
residual_grad
)[
offset
];
float4
dresidual
=
((
const
float4
*
)
residual_grad
)[
offset
];
float4
dresidual_1
=
((
const
float4
*
)
residual_grad
)[
offset
+
1
];
float4
dresidual_1
=
((
const
float4
*
)
residual_grad
)[
offset
+
1
];
__half
*
hdres
=
reinterpret_cast
<
__half
*>
(
&
dresidual
);
__half
*
hdres
=
reinterpret_cast
<
__half
*>
(
&
dresidual
);
__half
*
hdres_1
=
reinterpret_cast
<
__half
*>
(
&
dresidual_1
);
__half
*
hdres_1
=
reinterpret_cast
<
__half
*>
(
&
dresidual_1
);
#pragma unroll
#pragma unroll
...
@@ -846,11 +851,10 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
...
@@ -846,11 +851,10 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
}
}
__global__
void
ker_ln_bw_dinp_x4
(
__half
*
inp_grad
,
const
__half
*
out_grad
,
__global__
void
ker_ln_bw_dinp_x4
(
__half
*
inp_grad
,
const
__half
*
out_grad
,
const
__half
*
residual_grad
,
const
__half
*
residual_grad
,
const
__half
*
inp_or_out
,
const
__half
*
inp_or_out
,
const
__half
*
gamma
,
const
__half
*
gamma
,
const
__half
*
betta
,
const
__half
*
betta
,
const
__half
*
vars
,
const
__half
*
vars
,
const
__half
*
means
,
const
__half
*
means
,
int
hidden_dim
)
{
int
hidden_dim
)
{
int
offset
=
blockIdx
.
x
*
hidden_dim
*
4
+
threadIdx
.
x
*
4
;
int
offset
=
blockIdx
.
x
*
hidden_dim
*
4
+
threadIdx
.
x
*
4
;
float2
dxhat
[
4
],
xhat
[
4
];
float2
dxhat
[
4
],
xhat
[
4
];
...
@@ -901,8 +905,9 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
...
@@ -901,8 +905,9 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
dxhat_2
[
i
].
y
=
vdout_2
.
y
*
vgamma_2
.
y
;
dxhat_2
[
i
].
y
=
vdout_2
.
y
*
vgamma_2
.
y
;
dxhat_3
[
i
].
x
=
vdout_3
.
x
*
vgamma_3
.
x
;
dxhat_3
[
i
].
x
=
vdout_3
.
x
*
vgamma_3
.
x
;
dxhat_3
[
i
].
y
=
vdout_3
.
y
*
vgamma_3
.
y
;
dxhat_3
[
i
].
y
=
vdout_3
.
y
*
vgamma_3
.
y
;
reduce_val
[
0
]
+=
dxhat
[
i
].
x
+
dxhat
[
i
].
y
+
dxhat_1
[
i
].
x
+
dxhat_1
[
i
].
y
+
dxhat_2
[
i
].
x
+
reduce_val
[
0
]
+=
dxhat
[
i
].
x
+
dxhat
[
i
].
y
+
dxhat_1
[
i
].
x
+
dxhat_1
[
i
].
y
+
dxhat_2
[
i
].
y
+
dxhat_3
[
i
].
x
+
dxhat_3
[
i
].
y
;
dxhat_2
[
i
].
x
+
dxhat_2
[
i
].
y
+
dxhat_3
[
i
].
x
+
dxhat_3
[
i
].
y
;
}
}
/*
/*
...
@@ -947,9 +952,12 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
...
@@ -947,9 +952,12 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
xhat_2
[
i
].
y
=
(
vout_2
.
y
-
vbetta_2
.
y
)
/
add_eps
(
vgamma_2
.
y
);
xhat_2
[
i
].
y
=
(
vout_2
.
y
-
vbetta_2
.
y
)
/
add_eps
(
vgamma_2
.
y
);
xhat_3
[
i
].
y
=
(
vout_3
.
y
-
vbetta_3
.
y
)
/
add_eps
(
vgamma_3
.
y
);
xhat_3
[
i
].
y
=
(
vout_3
.
y
-
vbetta_3
.
y
)
/
add_eps
(
vgamma_3
.
y
);
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_1
[
i
].
x
*
dxhat_1
[
i
].
x
+
xhat_1
[
i
].
y
*
dxhat_1
[
i
].
y
;
reduce_val
[
1
]
+=
reduce_val
[
1
]
+=
xhat_2
[
i
].
x
*
dxhat_2
[
i
].
x
+
xhat_2
[
i
].
y
*
dxhat_2
[
i
].
y
;
xhat_1
[
i
].
x
*
dxhat_1
[
i
].
x
+
xhat_1
[
i
].
y
*
dxhat_1
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_3
[
i
].
x
*
dxhat_3
[
i
].
x
+
xhat_3
[
i
].
y
*
dxhat_3
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_2
[
i
].
x
*
dxhat_2
[
i
].
x
+
xhat_2
[
i
].
y
*
dxhat_2
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_3
[
i
].
x
*
dxhat_3
[
i
].
x
+
xhat_3
[
i
].
y
*
dxhat_3
[
i
].
y
;
}
}
}
else
{
}
else
{
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
...
@@ -969,9 +977,12 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
...
@@ -969,9 +977,12 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
xhat_2
[
i
].
y
=
(
vinp_2
.
y
-
fmean
)
*
var_rsqrt
;
xhat_2
[
i
].
y
=
(
vinp_2
.
y
-
fmean
)
*
var_rsqrt
;
xhat_3
[
i
].
y
=
(
vinp_3
.
y
-
fmean
)
*
var_rsqrt
;
xhat_3
[
i
].
y
=
(
vinp_3
.
y
-
fmean
)
*
var_rsqrt
;
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_1
[
i
].
x
*
dxhat_1
[
i
].
x
+
xhat_1
[
i
].
y
*
dxhat_1
[
i
].
y
;
reduce_val
[
1
]
+=
reduce_val
[
1
]
+=
xhat_2
[
i
].
x
*
dxhat_2
[
i
].
x
+
xhat_2
[
i
].
y
*
dxhat_2
[
i
].
y
;
xhat_1
[
i
].
x
*
dxhat_1
[
i
].
x
+
xhat_1
[
i
].
y
*
dxhat_1
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_3
[
i
].
x
*
dxhat_3
[
i
].
x
+
xhat_3
[
i
].
y
*
dxhat_3
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_2
[
i
].
x
*
dxhat_2
[
i
].
x
+
xhat_2
[
i
].
y
*
dxhat_2
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_3
[
i
].
x
*
dxhat_3
[
i
].
x
+
xhat_3
[
i
].
y
*
dxhat_3
[
i
].
y
;
}
}
}
}
}
}
...
@@ -997,9 +1008,9 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
...
@@ -997,9 +1008,9 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
// Add the residual grad,
// Add the residual grad,
// usually in pre-layer-norm for transformer layer
// usually in pre-layer-norm for transformer layer
float4
dresidual
=
((
const
float4
*
)
residual_grad
)[
offset
];
float4
dresidual
=
((
const
float4
*
)
residual_grad
)[
offset
];
float4
dresidual_1
=
((
const
float4
*
)
residual_grad
)[
offset
+
1
];
float4
dresidual_1
=
((
const
float4
*
)
residual_grad
)[
offset
+
1
];
float4
dresidual_2
=
((
const
float4
*
)
residual_grad
)[
offset
+
2
];
float4
dresidual_2
=
((
const
float4
*
)
residual_grad
)[
offset
+
2
];
float4
dresidual_3
=
((
const
float4
*
)
residual_grad
)[
offset
+
3
];
float4
dresidual_3
=
((
const
float4
*
)
residual_grad
)[
offset
+
3
];
__half
*
hdres
=
reinterpret_cast
<
__half
*>
(
&
dresidual
);
__half
*
hdres
=
reinterpret_cast
<
__half
*>
(
&
dresidual
);
__half
*
hdres_1
=
reinterpret_cast
<
__half
*>
(
&
dresidual_1
);
__half
*
hdres_1
=
reinterpret_cast
<
__half
*>
(
&
dresidual_1
);
__half
*
hdres_2
=
reinterpret_cast
<
__half
*>
(
&
dresidual_2
);
__half
*
hdres_2
=
reinterpret_cast
<
__half
*>
(
&
dresidual_2
);
...
@@ -1139,22 +1150,21 @@ void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad,
...
@@ -1139,22 +1150,21 @@ void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad,
if
(
hidden_dim
*
8
<=
8192
)
{
if
(
hidden_dim
*
8
<=
8192
)
{
int
nthread
=
min
(((
hidden_dim
+
31
)
/
32
)
*
32
,
MAX_THREADS
);
int
nthread
=
min
(((
hidden_dim
+
31
)
/
32
)
*
32
,
MAX_THREADS
);
ker_ln_bw_dinp
<<<
batch
,
nthread
,
0
,
stream
[
1
]
>>>
(
ker_ln_bw_dinp
<<<
batch
,
nthread
,
0
,
stream
[
1
]
>>>
(
inp_grad
,
out_grad
,
residual_grad
,
inp_or_out
,
gamma
,
betta
,
vars
,
means
,
inp_grad
,
out_grad
,
residual_grad
,
inp_or_out
,
gamma
,
betta
,
vars
,
hidden_dim
);
means
,
hidden_dim
);
}
else
if
(
hidden_dim
*
8
>
8192
&&
hidden_dim
*
8
<=
8192
*
2
)
{
}
else
if
(
hidden_dim
*
8
>
8192
&&
hidden_dim
*
8
<=
8192
*
2
)
{
hidden_dim
>>=
1
;
hidden_dim
>>=
1
;
int
nthread
=
min
(((
hidden_dim
+
31
)
/
32
)
*
32
,
MAX_THREADS
);
int
nthread
=
min
(((
hidden_dim
+
31
)
/
32
)
*
32
,
MAX_THREADS
);
ker_ln_bw_dinp_x2
<<<
batch
,
nthread
,
0
,
stream
[
1
]
>>>
(
ker_ln_bw_dinp_x2
<<<
batch
,
nthread
,
0
,
stream
[
1
]
>>>
(
inp_grad
,
out_grad
,
residual_grad
,
inp_or_out
,
gamma
,
betta
,
vars
,
means
,
inp_grad
,
out_grad
,
residual_grad
,
inp_or_out
,
gamma
,
betta
,
vars
,
hidden_dim
);
means
,
hidden_dim
);
}
else
if
(
hidden_dim
*
8
>
2
*
8192
&&
hidden_dim
*
8
<=
8192
*
4
)
{
}
else
if
(
hidden_dim
*
8
>
2
*
8192
&&
hidden_dim
*
8
<=
8192
*
4
)
{
hidden_dim
>>=
2
;
hidden_dim
>>=
2
;
int
nthread
=
min
(((
hidden_dim
+
31
)
/
32
)
*
32
,
MAX_THREADS
);
int
nthread
=
min
(((
hidden_dim
+
31
)
/
32
)
*
32
,
MAX_THREADS
);
ker_ln_bw_dinp_x4
<<<
batch
,
nthread
,
0
,
stream
[
1
]
>>>
(
ker_ln_bw_dinp_x4
<<<
batch
,
nthread
,
0
,
stream
[
1
]
>>>
(
inp_grad
,
out_grad
,
residual_grad
,
inp_or_out
,
gamma
,
betta
,
vars
,
means
,
inp_grad
,
out_grad
,
residual_grad
,
inp_or_out
,
gamma
,
betta
,
vars
,
hidden_dim
);
means
,
hidden_dim
);
}
else
{
}
else
{
throw
std
::
runtime_error
(
"hidden_dim % 4 != 0 || hidden_dim > 32768"
);
throw
std
::
runtime_error
(
"hidden_dim % 4 != 0 || hidden_dim > 32768"
);
}
}
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment