Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
038ed999
Commit
038ed999
authored
Jul 29, 2022
by
hubertlu-tw
Browse files
Fix some compiling errors
parent
bbf2c8d0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
36 deletions
+35
-36
apex/contrib/csrc/multihead_attn/layer_norm.cuh
apex/contrib/csrc/multihead_attn/layer_norm.cuh
+2
-2
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu
...ihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu
+33
-12
csrc/layer_norm_cuda_kernel.cu
csrc/layer_norm_cuda_kernel.cu
+0
-22
No files found.
apex/contrib/csrc/multihead_attn/layer_norm.cuh
View file @
038ed999
...
...
@@ -261,7 +261,7 @@ cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean,
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for
(
auto
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
for
(
int
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
U
mu
,
sigma2
;
...
...
@@ -475,7 +475,7 @@ cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid,
const
T
*
__restrict__
input
,
const
int
n1
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
,
U
epsilon
,
const
T
*
gamma
,
T
*
grad_input
)
{
for
(
auto
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
for
(
int
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
U
sum_loss1
=
U
(
0
);
U
sum_loss2
=
U
(
0
);
const
U
c_mean
=
mean
[
i1
];
...
...
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu
View file @
038ed999
...
...
@@ -19,18 +19,13 @@ namespace multihead_attn {
namespace
self_bias_additive_mask
{
namespace
rocblas_gemmex
{
bool
use_time_mask
,
bool
is_training
,
int
heads
,
torch
::
Tensor
const
&
inputs
,
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
use_time_mask
,
bool
is_training
,
int
heads
,
torch
::
Tensor
const
&
inputs
,
torch
::
Tensor
const
&
input_weights
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
input_biases
,
torch
::
Tensor
const
&
output_biases
,
const
half
*
pad_mask
,
float
dropout_prob
)
{
const
half
*
pad_mask
,
float
dropout_prob
)
{
const
int
embed_dim
=
inputs
.
size
(
2
);
const
int
sequences
=
inputs
.
size
(
1
);
const
int
q_seq_len
=
inputs
.
size
(
0
);
...
...
@@ -49,6 +44,32 @@ namespace rocblas_gemmex {
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t
handle
=
at
::
cuda
::
getCurrentCUDABlasHandle
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cublasSetStream
(
handle
,
stream
);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto
act_options
=
inputs
.
options
().
requires_grad
(
false
);
auto
mask_options
=
act_options
.
dtype
(
torch
::
kUInt8
);
torch
::
Tensor
input_lin_results
=
torch
::
empty
({
q_seq_len
,
sequences
,
output_lin_dim
},
act_options
);
torch
::
Tensor
bmm1_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_mask
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
mask_options
);
torch
::
Tensor
matmul2_results
=
torch
::
empty
({
q_seq_len
,
attn_batches
,
head_dim
},
act_options
);
torch
::
Tensor
outputs
=
torch
::
empty_like
(
inputs
,
act_options
);
// Input Linear Results Pointers to Q, K, and V of interviewed activations
void
*
q_lin_results_ptr
=
static_cast
<
void
*>
(
input_lin_results
.
data_ptr
());
void
*
k_lin_results_ptr
=
static_cast
<
void
*>
(
static_cast
<
half
*>
(
input_lin_results
.
data_ptr
())
+
head_dim
);
void
*
v_lin_results_ptr
=
static_cast
<
void
*>
(
static_cast
<
half
*>
(
input_lin_results
.
data_ptr
())
+
2
*
head_dim
);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
...
...
csrc/layer_norm_cuda_kernel.cu
View file @
038ed999
...
...
@@ -931,28 +931,6 @@ void HostApplyRMSNorm(
output
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
,
warp_size
);
}
template
<
typename
T
,
typename
U
,
typename
V
=
T
>
void
HostApplyRMSNorm
(
V
*
output
,
U
*
invvar
,
const
T
*
input
,
int
n1
,
int
n2
,
double
epsilon
,
const
V
*
gamma
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
dim3
threads
(
32
,
4
,
1
);
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
int
nshared
=
threads
.
y
>
1
?
threads
.
y
*
sizeof
(
U
)
+
(
threads
.
y
/
2
)
*
sizeof
(
U
)
:
0
;
cuApplyRMSNorm
<<<
blocks
,
threads
,
nshared
,
stream
>>>
(
output
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
);
}
void
cuda_layer_norm
(
at
::
Tensor
*
output
,
at
::
Tensor
*
mean
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment