OpenDAS / apex · Commits

Commit c2b62b7f
Authored Mar 13, 2025 by JR_ZZU 🌴
delete origin files

Parent: 2a4864d5
Changes: 164
Showing 20 changed files with 0 additions and 5488 deletions (+0, -5488).
Changed files (all deletions, +0 additions each):

apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh  (-135)
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp  (-25)
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu  (-215)
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh  (-45)
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp  (-86)
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu  (-1037)
apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp  (-21)
apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu  (-294)
apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp  (-20)
apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu  (-228)
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp  (-36)
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu  (-506)
apex/contrib/csrc/peer_memory/peer_memory.cpp  (-29)
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu  (-750)
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh  (-50)
apex/contrib/csrc/transducer/transducer_joint.cpp  (-98)
apex/contrib/csrc/transducer/transducer_joint_kernel.cu  (-985)
apex/contrib/csrc/transducer/transducer_loss.cpp  (-109)
apex/contrib/csrc/transducer/transducer_loss_kernel.cu  (-767)
apex/contrib/csrc/xentropy/interface.cpp  (-52)
Too many changes to show: to preserve performance, only 164 of 164+ files are displayed.
apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh (deleted, 100644 → 0)

#pragma once
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
//#include "cutlass/cutlass.h"
//#include "cutlass/gemm/gemm.h"
//#include "cutlass/gemm/wmma_gemm_traits.h"
// symbol to be automatically resolved by PyTorch libs
/*
rocblas_datatype a_type = rocblas_datatype_f16_r; // OK
rocblas_datatype b_type = rocblas_datatype_f16_r; // OK
rocblas_datatype c_type = rocblas_datatype_f16_r; // OK
rocblas_datatype d_type = rocblas_datatype_f16_r;
rocblas_datatype compute_type = rocblas_datatype_f32_r;
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
int32_t solution_index = 0;
rocblas_int flags = 0;
*/
namespace {

cublasOperation_t convertTransToCublasOperation(char trans) {
  if (trans == 't')
    return CUBLAS_OP_T;
  else if (trans == 'n')
    return CUBLAS_OP_N;
  else if (trans == 'c')
    return CUBLAS_OP_C;
  else {
    AT_ERROR("trans must be one of: t, n, c");
    return CUBLAS_OP_T;
  }
}

void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
                               float alpha, const half *a, long lda, long strideA,
                               const half *b, long ldb, long strideB, float beta,
                               half *c, long ldc, long strideC, half *d, long ldd,
                               long strideD, long batchCount,
                               rocblas_gemm_algo algo, rocblas_int flags) {
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);
  float fAlpha = alpha;
  float fBeta = beta;
  //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(
      handle, opa, opb, (int)m, (int)n, (int)k,
      (void *)&fAlpha,
      a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
      b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
      (void *)&fBeta,
      c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
      d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
      (int)batchCount,
      rocblas_datatype_f32_r /*compute_type*/,
      algo, 0 /*solution_index*/, flags));
}

void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
                           float alpha, const half *a, long lda, long strideA,
                           const half *b, long ldb, long strideB, float beta,
                           half *c, long ldc, long strideC, half *d, long ldd,
                           long strideD, long batchCount, rocblas_int flags) {
  auto stream = c10::cuda::getCurrentCUDAStream();
  if ((transa == 't') && (transb == 'n')) {
    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA,
                                b, ldb, strideB, beta, c, ldc, strideC, d, ldd,
                                strideD, batchCount, rocblas_gemm_algo_standard, flags);
    } else {
      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA,
                                b, ldb, strideB, beta, c, ldc, strideC, d, ldd,
                                strideD, batchCount, rocblas_gemm_algo_standard, flags);
    }
  } else if ((transa == 'n') && (transb == 'n')) {
    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA,
                                b, ldb, strideB, beta, c, ldc, strideC, d, ldd,
                                strideD, batchCount, rocblas_gemm_algo_standard, flags);
    } else {
      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA,
                                b, ldb, strideB, beta, c, ldc, strideC, d, ldd,
                                strideD, batchCount, rocblas_gemm_algo_standard, flags);
    }
  } else if ((transa == 'n') && (transb == 't')) {
    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA,
                                b, ldb, strideB, beta, c, ldc, strideC, d, ldd,
                                strideD, batchCount, rocblas_gemm_algo_standard, flags);
    } else {
      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA,
                                b, ldb, strideB, beta, c, ldc, strideC, d, ldd,
                                strideD, batchCount, rocblas_gemm_algo_standard, flags);
    }
  } else {
    AT_ASSERTM(false, "TransA and TransB are invalid");
  }
}

void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k,
                    int64_t *lda, int64_t *ldb, int64_t *ldc) {
  int transa_ = ((transa == 't') || (transa == 'T'));
  int transb_ = ((transb == 't') || (transb == 'T'));

  // Note: leading dimensions generally are checked that they are > 0 and at
  // least as big the result requires (even if the value won't be used).
  if (n <= 1)
    *ldc = std::max<int64_t>(m, 1);

  if (transa_) {
    if (m <= 1)
      *lda = std::max<int64_t>(k, 1);
  } else {
    if (k <= 1)
      *lda = std::max<int64_t>(m, 1);
  }

  if (transb_) {
    if (k <= 1)
      *ldb = std::max<int64_t>(n, 1);
  } else {
    if (n <= 1)
      *ldb = std::max<int64_t>(k, 1);
  }
}

void HgemmStridedBatched(char transa, char transb, long m, long n, long k,
                         float alpha, const half *a, long lda, long strideA,
                         const half *b, long ldb, long strideB, float beta,
                         half *c, long ldc, long strideC, half *d, long ldd,
                         long strideD, long batchCount) {
  if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) ||
      (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX)) {
    AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, "
             "batchCount"
             "with the bound [val] <= %d", INT_MAX);
  }

  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

  // gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
  //                       b, ldb, strideB, beta, c, ldc, strideC, batchCount);
  gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb,
                        strideB, beta, c, ldc, strideC, d, ldd, strideD,
                        batchCount, 0 /*flags*/);
}

} // namespace
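For context, the header above is only a thin wrapper over rocBLAS: callers pass BLAS-style column-major dimensions, leading dimensions, and per-matrix batch strides, and accumulation happens in fp32. The sketch below is our own illustration of such a call and is not part of the deleted file; the function name, buffers, and sizes are hypothetical, and it assumes the header above is available to include.

#include <cuda_fp16.h>
#include "strided_batched_gemm.cuh"  // the deleted header reproduced above

// Hypothetical helper: multiply a batch of column-major m×k matrices A by
// k×n matrices B into m×n outputs D (C is the input accumulator buffer),
// with no transposes. All pointers are assumed to be device buffers holding
// `batch` matrices laid out back to back.
void batched_half_gemm_sketch(const half* A, const half* B, half* C, half* D,
                              long m, long n, long k, long batch) {
  const float alpha = 1.0f, beta = 0.0f;
  HgemmStridedBatched('n', 'n', m, n, k, alpha,
                      A, /*lda=*/m, /*strideA=*/m * k,
                      B, /*ldb=*/k, /*strideB=*/k * n, beta,
                      C, /*ldc=*/m, /*strideC=*/m * n,
                      D, /*ldd=*/m, /*strideD=*/m * n, batch);
}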
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp (deleted, 100644 → 0)

/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nccl_p2p_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id");
  m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm");
  m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace");
  m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange");
  m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay");
}
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu (deleted, 100644 → 0)

#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <ctime>
#include <cassert>
#ifdef __HIP_PLATFORM_HCC__
#include "rccl/rccl.h"
#else
#include "nccl.h"
#endif
/*
 * This file implements a crude but effective mechanism for copying data between tensors owned by different ranks
* on the same machine using cudaMemcpyAsync peer-to-peer transfers.
*/
namespace {

__global__ void AddDelay_kernel(const int delay, int* counter) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
    int new_counter = 0;
    double elapsed = 0;
    clock_t start = clock();
    do {
      clock_t now = clock();
      elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC;
      ++new_counter;
    } while (elapsed < (double)delay);
    *counter = new_counter;
  }
}

class NcclCommWrapper {
 private:
  ncclComm_t comm;
  int rank, world_size;

  ncclDataType_t get_nccl_type(at::Tensor input) {
    switch (input.scalar_type()) {
      case at::ScalarType::Half:     return ncclFloat16;
      case at::ScalarType::Float:    return ncclFloat32;
      case at::ScalarType::Double:   return ncclFloat64;
      case at::ScalarType::Byte:     return ncclUint8;
      case at::ScalarType::Char:     return ncclInt8;
      case at::ScalarType::Int:      return ncclInt32;
      case at::ScalarType::Long:     return ncclInt64;
      case at::ScalarType::BFloat16: return ncclBfloat16;
      default:
        assert(false);
    }
  }

 public:
  NcclCommWrapper() {
    memset(&comm, 0, sizeof(ncclComm_t));
    rank = 0;
    world_size = 0;
  }

  NcclCommWrapper(ncclUniqueId id, int my_rank, int num_ranks) {
    ncclCommInitRank(&comm, num_ranks, id, my_rank);
    rank = my_rank;
    world_size = num_ranks;
  }

  ~NcclCommWrapper() {
    printf("ncclCommDestroy()\n");
    ncclCommDestroy(comm);
  }

  void left_right_halo_exchange_inplace(int left_rank, int right_rank,
                                        at::Tensor left_output_halo, at::Tensor right_output_halo,
                                        at::Tensor left_input_halo, at::Tensor right_input_halo) {
    auto stream = at::cuda::getCurrentCUDAStream();
    ncclGroupStart();
    ncclDataType_t ncclType = get_nccl_type(left_output_halo);
    bool left_zero = (left_rank < 0);
    bool right_zero = (right_rank < 0);
    size_t left_n = torch::numel(left_output_halo);
    size_t right_n = torch::numel(right_output_halo);
    assert(left_n > 0 && left_n == right_n);
    if (left_zero) {
      left_input_halo.zero_();
    } else {
      AT_DISPATCH_ALL_TYPES_AND3(
          at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
          left_output_halo.scalar_type(), "left_halo_exch", [&]() {
            // send left (to my_rank - 1)
            ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, left_rank, comm, stream);
            // receive left (from my_rank - 1)
            ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, left_rank, comm, stream);
          });
    }
    if (right_zero) {
      right_input_halo.zero_();
    } else {
      AT_DISPATCH_ALL_TYPES_AND3(
          at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
          right_output_halo.scalar_type(), "right_halo_exch", [&]() {
            // send right (to my_rank + 1 )
            ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, right_rank, comm, stream);
            // receive right (from my_rank + 1)
            ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, right_rank, comm, stream);
          });
    }
    ncclGroupEnd();
  }

  std::vector<at::Tensor> left_right_halo_exchange(int left_rank, int right_rank,
                                                   at::Tensor left_output_halo,
                                                   at::Tensor right_output_halo) {
    // after halo exchange:
    // left_output_halo of rank+1 ends up in right_input_halo of rank
    // right_output_halo of rank-1 ends up in left_input_halo of rank
    auto right_input_halo = torch::empty_like(left_output_halo);
    auto left_input_halo = torch::empty_like(right_output_halo);
    left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo,
                                     left_input_halo, right_input_halo);
    return {left_input_halo, right_input_halo};
  }
};

class ManagedObjects {
 public:
  ManagedObjects() {}
  ~ManagedObjects() {
    for (auto it = _nccl_comms.begin(); it != _nccl_comms.end(); ++it) {
      delete *it;
    }
  }
  int add_comm(NcclCommWrapper* comm) {
    int handle = _nccl_comms.size();
    _nccl_comms.push_back(comm);
    return handle;
  }
  NcclCommWrapper& get_comm(int handle) {
    assert(handle >= 0 && handle < _nccl_comms.size());
    return *_nccl_comms[handle];
  }

 private:
  std::vector<NcclCommWrapper*> _nccl_comms;
};
class ManagedObjects mo;

} // end anonymous namespace

namespace apex {
namespace contrib {
namespace nccl_p2p {

at::Tensor get_unique_nccl_id(int n) {
  ncclUniqueId id;
  ncclGetUniqueId(&id);
  auto id_tensor = torch::empty({n, (int)sizeof(ncclUniqueId)},
                                torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
  auto id_ptr = id_tensor.data_ptr<uint8_t>();
  size_t offset = 0;
  for (int i = 0; i < n; ++i) {
    ncclUniqueId id;
    ncclGetUniqueId(&id);
    memcpy(id_ptr + offset, &id, sizeof(ncclUniqueId));
    offset += sizeof(ncclUniqueId);
  }
  return id_tensor;
}

int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks) {
  ncclUniqueId id;
  auto unique_nccl_id_ptr = unique_nccl_id.data_ptr<uint8_t>();
  memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId));
  NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks);
  int handle = mo.add_comm(comm);
  comm = 0L;
  return handle;
}

void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank,
                                      at::Tensor left_output_halo, at::Tensor right_output_halo,
                                      at::Tensor left_input_halo, at::Tensor right_input_halo) {
  class NcclCommWrapper& communicator = mo.get_comm(handle);
  return communicator.left_right_halo_exchange_inplace(left_rank, right_rank,
                                                       left_output_halo, right_output_halo,
                                                       left_input_halo, right_input_halo);
}

std::vector<at::Tensor> left_right_halo_exchange(int handle, int left_rank, int right_rank,
                                                 at::Tensor left_output_halo,
                                                 at::Tensor right_output_halo) {
  class NcclCommWrapper& communicator = mo.get_comm(handle);
  return communicator.left_right_halo_exchange(left_rank, right_rank,
                                               left_output_halo, right_output_halo);
}

void add_delay(int delay) {
  auto stream = at::cuda::getCurrentCUDAStream();
  auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
  AddDelay_kernel<<<1, 1, 0, stream>>>(delay, t.data_ptr<int>());
}

}}}
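As a usage note (not part of the commit): the bindings above are normally driven from Python, but the same call sequence can be written directly against the declarations in nccl_p2p_cuda.cuh. The sketch below is hypothetical; the function name and tensor shapes are ours, and the step where rank 0 broadcasts the ncclUniqueId tensor to the other ranks is elided.

#include <torch/torch.h>
#include "nccl_p2p_cuda.cuh"  // the deleted header from this commit

// Hypothetical halo-exchange flow for one rank of a 1D pipeline.
void halo_exchange_sketch(at::Tensor shared_id, int my_rank, int num_ranks) {
  // `shared_id` must hold the same ncclUniqueId bytes on every rank
  // (e.g. produced by get_unique_nccl_id(1) on rank 0 and broadcast).
  int handle = apex::contrib::nccl_p2p::init_nccl_comm(shared_id, my_rank, num_ranks);

  auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  auto left_out  = torch::ones({4, 128}, opts);  // halo rows sent to my_rank - 1
  auto right_out = torch::ones({4, 128}, opts);  // halo rows sent to my_rank + 1
  int left_rank  = (my_rank > 0) ? my_rank - 1 : -1;             // -1: no neighbour, zero-fill
  int right_rank = (my_rank < num_ranks - 1) ? my_rank + 1 : -1;

  auto halos = apex::contrib::nccl_p2p::left_right_halo_exchange(
      handle, left_rank, right_rank, left_out, right_out);
  at::Tensor left_in  = halos[0];  // filled from the right halo of my_rank - 1
  at::Tensor right_in = halos[1];  // filled from the left halo of my_rank + 1
}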
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh (deleted, 100644 → 0)

/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>
#ifndef _nccl_p2p_h_
#define _nccl_p2p_h_
namespace apex {
namespace contrib {
namespace nccl_p2p {

at::Tensor get_unique_nccl_id(int n);

int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks);

void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank,
                                      at::Tensor left_output_halo, at::Tensor right_output_halo,
                                      at::Tensor left_input_halo, at::Tensor right_input_halo);

std::vector<at::Tensor> left_right_halo_exchange(int handle, int left_rank, int right_rank,
                                                 at::Tensor left_output_halo,
                                                 at::Tensor right_output_halo);

void add_delay(int delay);

}}}
#endif
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp (deleted, 100644 → 0)

#include <torch/extension.h>
// CUDA forward declaration
void fused_strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy,
                                int stride, int clear_overflow_first);

void fused_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g,
                     float lr, float beta1, float beta2, float eps, float grad_scale,
                     int step, int mode, int bias_correction, float decay);
void fused_reversible_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g,
                                float lr, float beta1, float beta2, float eps, float grad_scale,
                                int step, int mode, int bias_correction, float decay);
void fused_maybe_adam_undo_cuda(at::Tensor& overflow_flag, at::Tensor& p, at::Tensor& m, at::Tensor& v, at::Tensor& g,
                                float lr, float beta1, float beta2, float eps, float grad_scale,
                                int step, int mode, int bias_correction, float decay);

void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag,
                        std::vector<std::vector<at::Tensor>> tensor_lists,
                        float lr, float beta1, float beta2, float eps, float grad_scale,
                        int step, int mode, int bias_correction, float decay);

void maybe_cast_cuda(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out);
void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag,
                        std::vector<std::vector<at::Tensor>> tensor_lists);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// C++ interface
void strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy,
                          int stride, int clear_overflow_first) {
  CHECK_INPUT(p_copy);
  fused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first);
}

void adam(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g,
          float lr, float beta1, float beta2, float eps, float grad_scale,
          int step, int mode, int bias_correction, float decay) {
  CHECK_INPUT(p);
  if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
  CHECK_INPUT(m);
  CHECK_INPUT(v);
  CHECK_INPUT(g);
  int64_t num_elem = p.numel();
  AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
  AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
  AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
  AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0,
             "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

  fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale,
                  step, mode, bias_correction, decay);
}

void reversible_adam(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g,
                     float lr, float beta1, float beta2, float eps, float grad_scale,
                     int step, int mode, int bias_correction, float decay) {
  CHECK_INPUT(p);
  if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
  CHECK_INPUT(m);
  CHECK_INPUT(v);
  CHECK_INPUT(g);
  int64_t num_elem = p.numel();
  AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
  AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
  AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
  AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0,
             "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

  fused_reversible_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale,
                             step, mode, bias_correction, decay);
}

void maybe_adam_undo(at::Tensor& overflow_flag, at::Tensor& p, at::Tensor& m, at::Tensor& v, at::Tensor& g,
                     float lr, float beta1, float beta2, float eps, float grad_scale,
                     int step, int mode, int bias_correction, float decay) {
  CHECK_INPUT(p);
  CHECK_INPUT(m);
  CHECK_INPUT(v);
  CHECK_INPUT(g);
  int64_t num_elem = p.numel();
  AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
  AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
  AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");

  fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale,
                             step, mode, bias_correction, decay);
}

void maybe_cast(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out) {
  CHECK_INPUT(p_in);
  CHECK_INPUT(p_out);
  int64_t num_elem = p_in.numel();
  AT_ASSERTM(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal");

  maybe_cast_cuda(overflow_flag, p_in, p_out);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
  m.def("adam", &adam, "Adam optimized CUDA implementation.");
  m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.");
  m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
  m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
  m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.");
  m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
}
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu (deleted, 100644 → 0)

#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include "ATen/TensorUtils.h"
// #include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T* p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}

template <typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) {
  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

#include "type_shim.h"

typedef enum {
  ADAM_MODE_0 = 0, // eps under square root
  ADAM_MODE_1 = 1  // eps outside square root
} adamMode_t;

template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1, const float b2, const float eps,
    const float grad_scale, const float step_size,
    const size_t tsize, adamMode_t mode, const float decay)
{
  //Assuming 2D grids and 2D blocks
  const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
  const int threadsPerBlock = blockDim.x * blockDim.y;
  const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
  const int i = (blockId * threadsPerBlock + threadIdInBlock);
  const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

  for (int j = i; j < tsize; j += totThreads) {
    T scaled_grad = g[j] / grad_scale;
    m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
    v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
    float denom;
    if (mode == ADAM_MODE_0)
      denom = sqrtf(v[j] + eps);
    else // Mode 1
      denom = sqrtf(v[j]) + eps;
    float update = (m[j] / denom) + (decay * p[j]);
    p[j] = p[j] - (step_size * update);
    if (p_copy != NULL) p_copy[j] = (GRAD_T)p[j];
  }
}

template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size,
      volatile int* noop_gmem,
      TensorListMetadata<DEPTH>& tl,
      const float b1, const float b2, const float eps,
      const float grad_scale, const float step_size,
      adamMode_t mode, const float decay)
  {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T* p = (T*)tl.addresses[0][tensor_loc];
    p += chunk_idx * chunk_size;
    T* m = (T*)tl.addresses[1][tensor_loc];
    m += chunk_idx * chunk_size;
    T* v = (T*)tl.addresses[2][tensor_loc];
    v += chunk_idx * chunk_size;
    GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc];
    g += chunk_idx * chunk_size;
    GRAD_T* p_copy = NULL;
    if (DEPTH == 5) {
      p_copy = (GRAD_T*)tl.addresses[4][tensor_loc];
      p_copy += chunk_idx * chunk_size;
    }

    n -= chunk_idx * chunk_size;

    T incoming_p[ILP];
    T incoming_m[ILP];
    T incoming_v[ILP];
    T incoming_g[ILP];

    // to make things simple, we put aligned case in a different code path
    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && is_aligned(m) &&
        is_aligned(v) && is_aligned(g) && is_aligned(p_copy)) {
      for (int i_start = threadIdx.x;
           i_start * ILP < n && i_start * ILP < chunk_size;
           i_start += blockDim.x) {
        // load
        GRAD_T tmp_g[ILP];
        load_store(incoming_p, p, 0, i_start);
        load_store(incoming_m, m, 0, i_start);
        load_store(incoming_v, v, 0, i_start);
        load_store(tmp_g, g, 0, i_start);
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          incoming_g[ii] = static_cast<T>(tmp_g[ii]);
          T scaled_grad = incoming_g[ii] / grad_scale;
          incoming_m[ii] = b1 * incoming_m[ii] + (1 - b1) * scaled_grad;
          incoming_v[ii] = b2 * incoming_v[ii] + (1 - b2) * scaled_grad * scaled_grad;
          float denom;
          if (mode == ADAM_MODE_0)
            denom = sqrtf(incoming_v[ii] + eps);
          else // Mode 1
            denom = sqrtf(incoming_v[ii]) + eps;
          float update = (incoming_m[ii] / denom) + (decay * incoming_p[ii]);
          incoming_p[ii] = incoming_p[ii] - (step_size * update);
          if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
        }
        load_store(p, incoming_p, i_start, 0);
        load_store(m, incoming_m, i_start, 0);
        load_store(v, incoming_v, i_start, 0);
        if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
      }
    } else {
      for (int i_start = 0;
           i_start < n && i_start < chunk_size;
           i_start += blockDim.x * ILP) {
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          incoming_p[ii] = 0;
          incoming_m[ii] = 0;
          incoming_v[ii] = 0;
          incoming_g[ii] = 0;
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) {
            incoming_p[ii] = p[i];
            incoming_m[ii] = m[i];
            incoming_v[ii] = v[i];
            incoming_g[ii] = static_cast<T>(g[i]);
          }
        }

        // note for clarification to future michael:
        // From a pure memory dependency perspective, there's likely no point unrolling
        // the write loop, since writes just fire off once their LDGs arrive.
        // Put another way, the STGs are dependent on the LDGs, but not on each other.
        // There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int j = i_start + threadIdx.x + ii * blockDim.x;
          if (j < n && j < chunk_size) {
            T scaled_grad = incoming_g[ii] / grad_scale;
            m[j] = b1 * incoming_m[ii] + (1 - b1) * scaled_grad;
            v[j] = b2 * incoming_v[ii] + (1 - b2) * scaled_grad * scaled_grad;
            float denom;
            if (mode == ADAM_MODE_0)
              denom = sqrtf(v[j] + eps);
            else // Mode 1
              denom = sqrtf(v[j]) + eps;
            float update = (m[j] / denom) + (decay * incoming_p[ii]);
            p[j] = incoming_p[ii] - (step_size * update);
            if (DEPTH == 5) p_copy[j] = (GRAD_T)p[j];
          }
        }
      }
    }
  }
};

void fused_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g,
                     float lr, float beta1, float beta2, float eps, float grad_scale,
                     int step, int mode, int bias_correction, float decay)
{
  // using namespace at;

  //Get tensor size
  int tsize = p.numel();
  //Determine #threads and #blocks
  const int threadsPerBlock = 512;
  const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);
  AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
  //Constants
  float step_size = 0;
  if (bias_correction == 1) {
    const float bias_correction1 = 1 - std::pow(beta1, step);
    const float bias_correction2 = 1 - std::pow(beta2, step);
    step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
  } else {
    step_size = lr;
  }
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
    //all other values should be fp32 for half gradients
    AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
    //dispatch is done on the gradient type
    using namespace at; // prevents "toString is undefined" errors
    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
            p.DATA_PTR<accscalar_t>(),
            p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
            m.DATA_PTR<accscalar_t>(),
            v.DATA_PTR<accscalar_t>(),
            g.DATA_PTR<scalar_t_0>(),
            beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);
    );
  } else {
    using namespace at;
    DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
        adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
            p.DATA_PTR<scalar_t_0>(),
            NULL, //don't output p_copy for fp32, it's wasted write
            m.DATA_PTR<scalar_t_0>(),
            v.DATA_PTR<scalar_t_0>(),
            g.DATA_PTR<scalar_t_0>(),
            beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);
    );
  }
  C10_CUDA_CHECK(cudaGetLastError());
}

void fused_adam_cuda_mt(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
    float lr, float beta1, float beta2, float eps, float grad_scale,
    int step, int mode, int bias_correction, float decay)
{
  //Constants
  float step_size = 0;
  if (bias_correction == 1) {
    const float bias_correction1 = 1 - std::pow(beta1, step);
    const float bias_correction2 = 1 - std::pow(beta2, step);
    step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
  } else {
    step_size = lr;
  }
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  size_t tl_sz = tensor_lists.size();
  AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");

  if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half ||
      tensor_lists[3][0].scalar_type() == at::ScalarType::BFloat16) {
    //alher values should be fp32 for half gradients
    AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
    //dich is done on the gradient type
    if (tl_sz == 5) {
      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
          using accscalar_t = at::acc_type<scalar_t_0, true>;
          multi_tensor_apply<5>(
              BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
              AdamFunctor<5, accscalar_t, scalar_t_0>(),
              beta1, beta2, eps, grad_scale, step_size, (adamMode_t)mode, decay);
      );
    } else {
      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
          using accscalar_t = at::acc_type<scalar_t_0, true>;
          multi_tensor_apply<4>(
              BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
              AdamFunctor<4, accscalar_t, scalar_t_0>(),
              beta1, beta2, eps, grad_scale, step_size, (adamMode_t)mode, decay);
      );
    }
  } else {
    if (tl_sz == 5) {
      DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
          multi_tensor_apply<5>(
              BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
              AdamFunctor<5, scalar_t_0, scalar_t_0>(),
              beta1, beta2, eps, grad_scale, step_size, (adamMode_t)mode, decay);
      );
    } else {
      DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
          multi_tensor_apply<4>(
              BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
              AdamFunctor<4, scalar_t_0, scalar_t_0>(),
              beta1, beta2, eps, grad_scale, step_size, (adamMode_t)mode, decay);
      );
    }
  }
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename FROM_T, typename TO_T>
__device__ void convert(const FROM_T vi, TO_T& vo) {
  vo = static_cast<TO_T>(vi);
}

template <>
__device__ void convert(const float vi, uint8_t& vo) {
  union S { float as_float; int as_int; };
  S s;
  s.as_float = vi;
  s.as_int = s.as_int & 0xFF800000;
  union T { at::Half as_half; uint8_t as_byte[2]; };
  T t;
  t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
  vo = t.as_byte[1];
}

template <>
__device__ void convert(const uint8_t vi, float& vo) {
  union T { at::Half as_half; uint8_t as_byte[2]; };
  T t;
  t.as_byte[0] = 0;
  t.as_byte[1] = vi;
  vo = static_cast<float>(t.as_half);
}

template <>
__device__ void convert(const at::Half vi, uint8_t& vo) {
  union S { float as_float; int as_int; };
  S s;
  s.as_float = static_cast<float>(vi);
  s.as_int = s.as_int & 0xFF800000;
  union T { at::Half as_half; uint8_t as_byte[2]; };
  T t;
  t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
  vo = t.as_byte[1];
}

template <>
__device__ void convert(const uint8_t vi, at::Half& vo) {
  union T { at::Half as_half; uint8_t as_byte[2]; };
  T t;
  t.as_byte[0] = 0;
  t.as_byte[1] = vi;
  vo = t.as_half;
}

template <typename GRAD_T>
__global__ void strided_check_finite_cuda_kernel(
    volatile int* noop_gmem,
    GRAD_T* __restrict__ p_copy,
    const size_t tsize, int stride, int clear_overflow_first)
{
  //Assuming 2D grids and 2D blocks
  const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
  const int threadsPerBlock = blockDim.x * blockDim.y;
  const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
  const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;
  const int totThreads = gridDim.x * gridDim.y * threadsPerBlock * stride;

  if (clear_overflow_first) {
    if (i == 0) {
      *noop_gmem = 0;
    }
    __syncthreads();
  }

  for (int j = i; j < tsize; j += totThreads) {
    GRAD_T pi = p_copy[j];
    if (!isfinite(pi)) {
      *noop_gmem = 1;
    }
  }
}

template <>
__global__ void strided_check_finite_cuda_kernel(
    volatile int* noop_gmem,
    uint8_t* __restrict__ p_copy,
    const size_t tsize, int stride, int clear_overflow_first)
{
  //Assuming 2D grids and 2D blocks
  const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
  const int threadsPerBlock = blockDim.x * blockDim.y;
  const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
  const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;
  const int totThreads = gridDim.x * gridDim.y * threadsPerBlock * stride;

  if (clear_overflow_first) {
    if (i == 0) {
      *noop_gmem = 0;
    }
    __syncthreads();
  }

  for (int j = i; j < tsize; j += totThreads) {
    at::Half pi;
    convert(p_copy[j], pi);
    if (!isfinite(pi)) {
      *noop_gmem = 1;
    }
  }
}

template <typename FROM_T, typename TO_T>
__global__ void maybe_cast_kernel(
    volatile int* overflow_flag,
    const FROM_T* p_in,
    TO_T* p_out,
    const size_t tsize)
{
  if (overflow_flag && *overflow_flag != 0) return;

  //Assuming 2D grids and 2D blocks
  const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
  const int threadsPerBlock = blockDim.x * blockDim.y;
  const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
  const int i = (blockId * threadsPerBlock + threadIdInBlock);
  const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

  FROM_T pi[ILP];
  TO_T po[ILP];

  for (int j_start = 0; j_start < tsize; j_start += totThreads * ILP) {
#pragma unroll
    for (int ii = 0; ii < ILP; ii++) {
      pi[ii] = 0;
      int j = j_start + i + totThreads * ii;
      if (j < tsize) {
        pi[ii] = p_in[j];
      }
    }
#pragma unroll
    for (int ii = 0; ii < ILP; ii++) {
      convert(pi[ii], po[ii]);
    }
#pragma unroll
    for (int ii = 0; ii < ILP; ii++) {
      int j = j_start + i + totThreads * ii;
      if (j < tsize) {
        p_out[j] = po[ii];
      }
    }
  }
}

template <typename T, typename GRAD_T, typename REDU_T>
__global__ void reversible_adam_cuda_kernel(
    T* __restrict__ p,
    REDU_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1, const float b2, const float eps,
    const float grad_scale, const float step_size,
    const size_t tsize, adamMode_t mode, const float decay)
{
  //Assuming 2D grids and 2D blocks
  const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
  const int threadsPerBlock = blockDim.x * blockDim.y;
  const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
  const int i = (blockId * threadsPerBlock + threadIdInBlock);
  const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

  T mi[ILP];
  T vi[ILP];
  T pi[ILP];
  T gi[ILP];

  bool overflow = false;
  for (int j_start = 0; j_start < tsize; j_start += totThreads * ILP) {
#pragma unroll
    for (int ii = 0; ii < ILP; ii++) {
      mi[ii] = T(0);
      vi[ii] = T(0);
      pi[ii] = T(0);
      gi[ii] = GRAD_T(0);
      int j = j_start + i + totThreads * ii;
      if (j < tsize) {
        pi[ii] = p[j];
        mi[ii] = m[j];
        vi[ii] = v[j];
        gi[ii] = static_cast<T>(g[j]);
      }
    }

#pragma unroll
    for (int ii = 0; ii < ILP; ii++) {
      T scaled_grad = gi[ii] / grad_scale;
      if (isfinite(scaled_grad)) {
        mi[ii] = b1 * mi[ii] + (1 - b1) * scaled_grad;
        vi[ii] = b2 * vi[ii] + (1 - b2) * scaled_grad * scaled_grad;
        float denom;
        if (mode == ADAM_MODE_0)
          denom = sqrtf(vi[ii] + eps);
        else // Mode 1
          denom = sqrtf(vi[ii]) + eps;
        float update = (mi[ii] / denom) + (decay * pi[ii]);
        pi[ii] = pi[ii] - (step_size * update);
      } else {
        overflow = true;
      }
    }

#pragma unroll
    for (int ii = 0; ii < ILP; ii++) {
      int j = j_start + i + totThreads * ii;
      if (j < tsize) {
        m[j] = mi[ii];
        v[j] = vi[ii];
        p[j] = pi[ii];
        if (p_copy != NULL) {
          convert(pi[ii], p_copy[j]);
        }
      }
    }
  }

  if (p_copy != NULL) {
    __syncthreads();
    if (overflow) {
      convert(float(INFINITY), p_copy[0]);
    }
  }
}

template <typename T, typename GRAD_T>
__global__ void maybe_adam_undo_cuda_kernel(
    volatile int* overflow_flag,
    T* __restrict__ p,
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1, const float b2, const float eps,
    const float grad_scale, const float step_size,
    const size_t tsize, adamMode_t mode, const float decay)
{
  // NB! Skip undo kernel when overflow flag is NOT set
  if (overflow_flag && *overflow_flag == 0) return;

  //Assuming 2D grids and 2D blocks
  const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
  const int threadsPerBlock = blockDim.x * blockDim.y;
  const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
  const int i = (blockId * threadsPerBlock + threadIdInBlock);
  const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

  T mi[ILP];
  T vi[ILP];
  T pi[ILP];
  T gi[ILP];

  for (int j_start = 0; j_start < tsize; j_start += totThreads * ILP) {
#pragma unroll
    for (int ii = 0; ii < ILP; ii++) {
      mi[ii] = T(0);
      vi[ii] = T(0);
      pi[ii] = T(0);
      gi[ii] = GRAD_T(0);
      int j = j_start + i * ILP;
      if (j < tsize) {
        pi[ii] = p[j];
        mi[ii] = m[j];
        vi[ii] = v[j];
        gi[ii] = static_cast<T>(g[j]);
      }
    }

#pragma unroll
    for (int ii = 0; ii < ILP; ii++) {
      T scaled_grad = gi[ii] / grad_scale;
      if (isfinite(scaled_grad)) {
        float denom;
        if (mode == ADAM_MODE_0)
          denom = sqrtf(vi[ii] + eps);
        else // Mode 1
          denom = sqrtf(vi[ii]) + eps;
        pi[ii] = (pi[ii] + step_size * (mi[ii] / denom)) / (1.0f - step_size * decay);
        mi[ii] = (mi[ii] - (1 - b1) * scaled_grad) / b1;
        vi[ii] = (vi[ii] - (1 - b2) * scaled_grad * scaled_grad) / b2;
        // Make sure round off errors don't create (small) negative value.
        // This can happen if we have to revert the very first step.
        vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f;
      }
    }

#pragma unroll
    for (int ii = 0; ii < ILP; ii++) {
      int j = j_start + i * ILP;
      if (j < tsize) {
        m[j] = mi[ii];
        v[j] = vi[ii];
        p[j] = pi[ii];
      }
    }
  }
}

template <int DEPTH, typename FROM_T, typename TO_T>
struct MaybeCastFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size,
      volatile int* overflow_flag,
      TensorListMetadata<DEPTH>& tl)
  {
    if (overflow_flag && *overflow_flag != 0) return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    FROM_T* p_in = (FROM_T*)tl.addresses[0][tensor_loc];
    p_in += chunk_idx * chunk_size;
    TO_T* p_out = (TO_T*)tl.addresses[1][tensor_loc];
    p_out += chunk_idx * chunk_size;

    n -= chunk_idx * chunk_size;
    int dim = chunk_size < n ? chunk_size : n;

    FROM_T pi[ILP];
    TO_T po[ILP];

    for (int j_start = 0; j_start < dim; j_start += blockDim.x * ILP) {
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        pi[ii] = FROM_T(0);
        int j = j_start + threadIdx.x + ii * blockDim.x;
        if (j < dim) {
          pi[ii] = p_in[j];
        }
      }
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        convert(pi[ii], po[ii]);
      }
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int j = j_start + threadIdx.x + ii * blockDim.x;
        if (j < dim) {
          p_out[j] = po[ii];
        }
      }
    }
  }
};

void fused_strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy,
                                int stride, int clear_overflow_first)
{
  //Get tensor size
  int tsize = p_copy.numel();
  int niter = (tsize + stride - 1) / stride;

  //Determine #threads and #blocks
  const int threadsPerBlock = 512;
  //In order to avoid race condition, blocks must be 1 when clear_overflow_first flag is set.
  const dim3 blocks(clear_overflow_first ? 1 : (niter + threadsPerBlock - 1) / threadsPerBlock);
  AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_copy), "parameter tensor is too large to be indexed with int32");

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  using namespace at; // prevents "toString is undefined" errors
  DISPATCH_FLOAT_HALF_AND_BYTE(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
      strided_check_finite_cuda_kernel<scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
          overflow_flag.DATA_PTR<int>(),
          p_copy.DATA_PTR<scalar_t_0>(),
          tsize, stride, clear_overflow_first);
  );
  C10_CUDA_CHECK(cudaGetLastError());
}

void fused_reversible_adam_cuda(
    at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g,
    float lr, float beta1, float beta2, float eps, float grad_scale,
    int step, int mode, int bias_correction, float decay)
{
  // using namespace at;

  //Get tensor size
  int tsize = p.numel();
  //Determine #threads and #blocks
  const int threadsPerBlock = 512;
  const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);
  AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
  //Constants
  float step_size = 0;
  if (bias_correction == 1) {
    const float bias_correction1 = 1 - std::pow(beta1, step);
    const float bias_correction2 = 1 - std::pow(beta2, step);
    step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
  } else {
    step_size = lr;
  }
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
    //all other values should be fp32 for half gradients
    AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
    //dispatch is done on the gradient type
    using namespace at; // prevents "toString is undefined" errors

    if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) {
      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel",
          using accscalar_t = at::acc_type<scalar_t_0, true>;
          reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
              p.DATA_PTR<accscalar_t>(),
              p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
              m.DATA_PTR<accscalar_t>(),
              v.DATA_PTR<accscalar_t>(),
              g.DATA_PTR<scalar_t_0>(),
              beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);
      );
    } else {
      AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_e5m2_kernel",
          using accscalar_t = at::acc_type<scalar_t_0, true>;
          reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, uint8_t><<<blocks, threadsPerBlock, 0, stream>>>(
              p.DATA_PTR<accscalar_t>(),
              p_copy.DATA_PTR<uint8_t>(),
              m.DATA_PTR<accscalar_t>(),
              v.DATA_PTR<accscalar_t>(),
              g.DATA_PTR<scalar_t_0>(),
              beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);
      );
    }
  } else {
    using namespace at;
    DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
        reversible_adam_cuda_kernel<scalar_t_0, scalar_t_0, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
            p.DATA_PTR<scalar_t_0>(),
            NULL, //don't output p_copy for fp32, it's wasted write
            m.DATA_PTR<scalar_t_0>(),
            v.DATA_PTR<scalar_t_0>(),
            g.DATA_PTR<scalar_t_0>(),
            beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);
    );
  }
  C10_CUDA_CHECK(cudaGetLastError());
}

void maybe_cast_cuda(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out)
{
  //Get tensor size
  int tsize = p_in.numel();
  AT_ASSERTM(tsize == p_out.numel(), "p_in.numel() must equal p_out.numel()");
  //Determine #threads and #blocks
  const int threadsPerBlock = 512;
  const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);
  AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32");
  //Constants
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  DISPATCH_FLOAT_HALF_AND_BYTE(p_in.scalar_type(), 0, "maybe_cast_cuda"
      DISPATCH_FLOAT_HALF_AND_BYTE(p_out.scalar_type(), 1, "maybe_cast_cuda",
          maybe_cast_kernel<scalar_t_0, scalar_t_1><<<blocks, threadsPerBlock, 0, stream>>>(
              overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
              p_in.DATA_PTR<scalar_t_0>(),
              p_out.DATA_PTR<scalar_t_1>(),
              tsize);
      ))
  C10_CUDA_CHECK(cudaGetLastError());
}

void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag,
                        std::vector<std::vector<at::Tensor>> tensor_lists) // p_in, p_out
{
  //Constants
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  size_t tl_sz = tensor_lists.size();
  AT_ASSERTM(tl_sz == 2, "expected tensor lists of size 2");
  DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[0][0].scalar_type(), 0, "maybe_cast_cuda_mt_kernel",
      DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 1, "maybe_cast_cuda_mt_kernel",
          multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, overflow_flag, tensor_lists,
                                MaybeCastFunctor<2, scalar_t_0, scalar_t_1>());
      ))
  C10_CUDA_CHECK(cudaGetLastError());
}

void fused_maybe_adam_undo_cuda(
    at::Tensor& overflow_flag, at::Tensor& p, at::Tensor& m, at::Tensor& v, at::Tensor& g,
    float lr, float beta1, float beta2, float eps, float grad_scale,
    int step, int mode, int bias_correction, float decay)
{
  //Get tensor size
  int tsize = p.numel();
  //Determine #threads and #blocks
  const int threadsPerBlock = 512;
  const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);
  AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
  //Constants
  float step_size = 0;
  if (bias_correction == 1) {
    const float bias_correction1 = 1 - std::pow(beta1, step);
    const float bias_correction2 = 1 - std::pow(beta2, step);
    step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
  } else {
    step_size = lr;
  }
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
    //all other values should be fp32 for half gradients
    AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
    //dispatch is done on the gradient type
    using namespace at; // prevents "toString is undefined" errors
    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        maybe_adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
            overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
            p.DATA_PTR<accscalar_t>(),
            m.DATA_PTR<accscalar_t>(),
            v.DATA_PTR<accscalar_t>(),
            g.DATA_PTR<scalar_t_0>(),
            beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);
    );
  } else {
    using namespace at;
    DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
        maybe_adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
            overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
            p.DATA_PTR<scalar_t_0>(),
            m.DATA_PTR<scalar_t_0>(),
            v.DATA_PTR<scalar_t_0>(),
            g.DATA_PTR<scalar_t_0>(),
            beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);
    );
  }
  C10_CUDA_CHECK(cudaGetLastError());
}
apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp (deleted, 100644 → 0)

#include <torch/extension.h>
void multi_tensor_lamb_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    const float lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    const int step,
    const int bias_correction,
    const float weight_decay,
    const int grad_averaging,
    const int mode,
    const float global_grad_norm,
    const float max_grad_norm);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("lamb", &multi_tensor_lamb_cuda, "Computes and apply update for LAMB optimizer");
}
apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu (deleted, 100644 → 0)

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
typedef
enum
{
MOMENT_MODE_0
=
0
,
// L2 regularization mode
MOMENT_MODE_1
=
1
// Decoupled weight decay mode
}
adamMode_t
;
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
multi_tensor_l2norm_cuda
(
int
chunk_size
,
at
::
Tensor
noop_flag
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
at
::
optional
<
bool
>
per_tensor_python
);
using
MATH_T
=
float
;
template<typename T>
struct LAMBStage1Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<4>& tl,
    const float beta1,
    const float beta2,
    const float beta3,
    const float beta1_correction,
    const float beta2_correction,
    const float epsilon,
    adamMode_t mode,
    const float decay,
    const float global_grad_norm,
    const float max_global_grad_norm)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;

    T* g = (T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    T* m = (T*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;

    T* v = (T*)tl.addresses[3][tensor_loc];
    v += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // see note in multi_tensor_scale_kernel.cu
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      MATH_T r_g[ILP];
      MATH_T r_p[ILP];
      MATH_T r_m[ILP];
      MATH_T r_v[ILP];
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_g[ii] = g[i];
          // special ?optimization? for lamb stage 1
          if (decay == 0) {
            r_p[ii] = MATH_T(0);
          }
          else {
            r_p[ii] = p[i];
          }
          r_m[ii] = m[i];
          r_v[ii] = v[i];
        } else {
          r_g[ii] = MATH_T(0);
          r_p[ii] = MATH_T(0);
          r_m[ii] = MATH_T(0);
          r_v[ii] = MATH_T(0);
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        if (mode == MOMENT_MODE_0) {
          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
          // L2 on scaled grad
          scaled_grad = scaled_grad + decay*r_p[ii];
          r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          r_p[ii] = next_m_unbiased / denom;
        }
        else {
          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
          r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          g[i] = r_p[ii];
          m[i] = r_m[ii];
          v[i] = r_v[ii];
        }
      }
    }
  }
};
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T>
struct LAMBStage2Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<2>& tl,
    const float* per_tensor_param_norm,
    const float* per_tensor_update_norm,
    const float learning_rate,
    const float decay)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    MATH_T ratio = learning_rate;
    // apply adaptive learning rate to parameters with non-zero weight decay
    if (decay != 0.0)
    {
      float param_norm = per_tensor_param_norm[tensor_num];
      float update_norm = per_tensor_update_norm[tensor_num];
      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
    }

    T* update = (T*)tl.addresses[0][tensor_loc];
    update += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      MATH_T r_p[ILP];
      MATH_T r_update[ILP];
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_p[ii] = p[i];
          r_update[ii] = update[i];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          p[i] = r_p[ii];
        }
      }
    }
  }
};
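Reading aid (added note, not in the original source): for tensors with non-zero weight decay, stage 2 scales the learning rate by the per-tensor trust ratio before applying the stage-1 update \(u\),

\[ p \leftarrow p - \eta \cdot \frac{\lVert p \rVert_2}{\lVert u \rVert_2} \, u, \]

falling back to \(p \leftarrow p - \eta\, u\) when either norm is zero or the tensor has no weight decay.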
void multi_tensor_lamb_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  const float global_grad_norm,
  const float max_grad_norm)
{
  using namespace at;
  // Master weight and 32bit momentum(potentially changing) is not handled by this
  // So we assume every tensor are all in the same type

  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  if (bias_correction == 1) {
    bias_correction1 = 1 - std::pow(beta1, step);
    bias_correction2 = 1 - std::pow(beta2, step);
  }

  // Handle grad averaging mode
  float beta3 = 1.0f;
  if (grad_averaging == 1) beta3 = 1 - beta1;

  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
  std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);

  // Compute per tensor param norm
  auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);

  // We now in-place modify grad to store update before compute its norm
  // Generally this is not a issue since people modify grad in step() method all the time
  // We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code
  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
      multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        LAMBStage1Functor<scalar_t_0>(),
        beta1,
        beta2,
        beta3, // 1-beta1 or 1 depends on averaging mode
        bias_correction1,
        bias_correction2,
        epsilon,
        (adamMode_t) mode,
        weight_decay,
        global_grad_norm,
        max_grad_norm); )

  // Compute update norms
  auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);

  std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
      multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        grad_param_list,
        LAMBStage2Functor<scalar_t_0>(),
        std::get<1>(param_norm_tuple).DATA_PTR<float>(),
        std::get<1>(update_norm_tuple).DATA_PTR<float>(),
        lr,
        weight_decay); )

  AT_CUDA_CHECK(cudaGetLastError());
}
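Caller-side sketch (illustrative only, not part of the original file): the nested tensor_lists argument is ordered [grads, params, exp_avg, exp_avg_sq], which is what LAMBStage1Functor's addresses[0..3] indexing and the grad_list/param_list slicing above assume; all tensors are expected to be CUDA tensors of the same dtype.

// Hypothetical helper (assumption, not part of apex) that builds the expected layout.
#include <vector>
#include <torch/torch.h>

std::vector<std::vector<at::Tensor>> make_lamb_lists(
    const std::vector<at::Tensor>& grads,
    const std::vector<at::Tensor>& params,
    const std::vector<at::Tensor>& exp_avg,
    const std::vector<at::Tensor>& exp_avg_sq) {
  // Order matters: [g, p, m, v].
  return {grads, params, exp_avg, exp_avg_sq};
}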
apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp deleted 100644 → 0
#include <torch/extension.h>

void multi_tensor_fused_adam_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor per_tensor_beta1,
    at::Tensor per_tensor_beta2,
    at::Tensor per_tensor_bias_correction,
    at::Tensor per_tensor_eps,
    at::Tensor per_tensor_weight_decay,
    float lr,
    float grad_scale,
    int step,
    int mode);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda,
        "Multi tensor Adam optimized CUDA implementation.");
}
apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu deleted 100644 → 0
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>
#include <cmath>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

typedef enum{
  ADAM_MODE_0 = 0, // eps under square root
  ADAM_MODE_1 = 1  // eps outside square root
} adamMode_t;
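The is_aligned/load_store pair above is what turns the per-element loop into single 128-bit vector loads whenever a chunk is ILP-aligned. A standalone sketch of the same idea (illustrative, not part of the original file; assumes float data, ILP = 4, and n divisible by 4):

// Minimal CUDA sketch of the aligned vector-copy trick used by load_store.
#include <cuda_runtime.h>
#include <type_traits>

#define ILP 4

__global__ void copy_vec4(float* dst, const float* src, int n) {
  // One aligned_storage element spans ILP floats = 16 bytes.
  typedef typename std::aligned_storage<ILP * sizeof(float), ILP * alignof(float)>::type LT;
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // index in units of 4 floats
  if (i * ILP < n) {
    ((LT*)dst)[i] = ((const LT*)src)[i];          // one 128-bit load + one 128-bit store
  }
}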
template <int DEPTH, typename T, typename GRAD_T>
struct DistAdamFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<DEPTH>& tl,
    const float* per_tensor_beta1,
    const float* per_tensor_beta2,
    const int* per_tensor_bias_correction,
    const float* per_tensor_eps,
    const float* per_tensor_weight_decay,
    const float lr,
    const float grad_scale,
    const int step,
    adamMode_t mode)
  {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float b1 = per_tensor_beta1[tensor_num];
    float b2 = per_tensor_beta2[tensor_num];
    float eps = per_tensor_eps[tensor_num];
    float decay = per_tensor_weight_decay[tensor_num];

    float beta1_correction = 1.0f, beta2_correction = 1.0f;
    if (per_tensor_bias_correction[tensor_num] == 1) {
      beta1_correction = 1 - std::pow(b1, step);
      beta2_correction = 1 - std::pow(b2, step);
    }

    T* p = (T*)tl.addresses[0][tensor_loc];
    p += chunk_idx*chunk_size;
    T* m = (T*)tl.addresses[1][tensor_loc];
    m += chunk_idx*chunk_size;
    T* v = (T*)tl.addresses[2][tensor_loc];
    v += chunk_idx*chunk_size;
    GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc];
    g += chunk_idx*chunk_size;
    GRAD_T* p_copy = NULL;
    if (DEPTH == 5) {
      p_copy = (GRAD_T*)tl.addresses[4][tensor_loc];
      p_copy += chunk_idx*chunk_size;
    }

    n -= chunk_idx*chunk_size;

    T incoming_p[ILP];
    T incoming_m[ILP];
    T incoming_v[ILP];
    T incoming_g[ILP];

    // to make things simple, we put aligned case in a different code path
    if (n % ILP == 0 && chunk_size % ILP == 0 &&
        is_aligned(p) && is_aligned(m) && is_aligned(v) && is_aligned(g) && is_aligned(p_copy))
    {
      for (int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) {
        // load
        GRAD_T tmp_g[ILP];
        load_store(incoming_p, p, 0, i_start);
        load_store(incoming_m, m, 0, i_start);
        load_store(incoming_v, v, 0, i_start);
        load_store(tmp_g, g, 0, i_start);
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          incoming_g[ii] = static_cast<T>(tmp_g[ii]);
          T scaled_grad = incoming_g[ii] / grad_scale;
          incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
          incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
          T next_m_unbiased = incoming_m[ii] / beta1_correction;
          T next_v_unbiased = incoming_v[ii] / beta2_correction;
          float denom;
          if (mode == ADAM_MODE_0)
            denom = sqrtf(next_v_unbiased + eps);
          else // Mode 1
            denom = sqrtf(next_v_unbiased) + eps;
          float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
          incoming_p[ii] = incoming_p[ii] - (lr * update);
          if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
        }
        load_store(p, incoming_p, i_start, 0);
        load_store(m, incoming_m, i_start, 0);
        load_store(v, incoming_v, i_start, 0);
        if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
      }
    }
    else
    {
      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) {
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          incoming_p[ii] = 0;
          incoming_m[ii] = 0;
          incoming_v[ii] = 0;
          incoming_g[ii] = 0;
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if (i < n && i < chunk_size) {
            incoming_p[ii] = p[i];
            incoming_m[ii] = m[i];
            incoming_v[ii] = v[i];
            incoming_g[ii] = static_cast<T>(g[i]);
          }
        }
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int j = i_start + threadIdx.x + ii*blockDim.x;
          if (j < n && j < chunk_size) {
            T scaled_grad = incoming_g[ii] / grad_scale;
            m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
            v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
            T next_m_unbiased = m[j] / beta1_correction;
            T next_v_unbiased = v[j] / beta2_correction;
            float denom;
            if (mode == ADAM_MODE_0)
              denom = sqrtf(next_v_unbiased + eps);
            else // Mode 1
              denom = sqrtf(next_v_unbiased) + eps;
            float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
            p[j] = incoming_p[ii] - (lr * update);
            if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
          }
        }
      }
    }
  }
};
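For reference (added note, not part of the original file), the two adamMode_t variants only differ in where eps enters the denominator:

\[ \text{ADAM\_MODE\_0: } d = \sqrt{\hat v + \epsilon}, \qquad \text{ADAM\_MODE\_1: } d = \sqrt{\hat v} + \epsilon, \qquad p \leftarrow p - \eta\left(\frac{\hat m}{d} + \lambda p\right), \]

with \(\hat m, \hat v\) the bias-corrected moments computed from the grad-scale-divided gradient.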
void multi_tensor_fused_adam_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
  at::Tensor per_tensor_beta1,
  at::Tensor per_tensor_beta2,
  at::Tensor per_tensor_bias_correction,
  at::Tensor per_tensor_eps,
  at::Tensor per_tensor_weight_decay,
  float lr,
  float grad_scale,
  int step,
  int mode)
{
  using namespace at;

  size_t tl_sz = tensor_lists.size();
  AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");

  if (tl_sz == 5) {
    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        multi_tensor_apply<5>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          DistAdamFunctor<5, accscalar_t, scalar_t_0>(),
          per_tensor_beta1.DATA_PTR<float>(),
          per_tensor_beta2.DATA_PTR<float>(),
          per_tensor_bias_correction.DATA_PTR<int>(),
          per_tensor_eps.DATA_PTR<float>(),
          per_tensor_weight_decay.DATA_PTR<float>(),
          lr,
          grad_scale,
          step,
          (adamMode_t) mode);
      );
  } else {
    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        multi_tensor_apply<4>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          DistAdamFunctor<4, accscalar_t, scalar_t_0>(),
          per_tensor_beta1.DATA_PTR<float>(),
          per_tensor_beta2.DATA_PTR<float>(),
          per_tensor_bias_correction.DATA_PTR<int>(),
          per_tensor_eps.DATA_PTR<float>(),
          per_tensor_weight_decay.DATA_PTR<float>(),
          lr,
          grad_scale,
          step,
          (adamMode_t) mode);
      );
  }
  C10_CUDA_CHECK(cudaGetLastError());
}
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp deleted 100644 → 0
#include <torch/extension.h>

void multi_tensor_lamb_compute_update_term_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor per_tensor_beta1,
    at::Tensor per_tensor_beta2,
    at::Tensor per_tensor_beta3,
    at::Tensor per_tensor_bias_correction,
    at::Tensor step,
    at::Tensor per_tensor_epsilon,
    const int mode,
    at::Tensor per_tensor_decay,
    at::Tensor global_scale,
    at::Tensor global_grad_norm,
    const float max_grad_norm);

void multi_tensor_lamb_update_weights_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor per_tensor_param_norm,
    at::Tensor per_tensor_update_norm,
    at::Tensor update_norm_offset,
    at::Tensor learning_rate,
    at::Tensor per_tensor_decay,
    at::Tensor global_grad_norm,
    bool use_nvlamb);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_cuda,
        "Computes update term for LAMB optimizer");
  m.def("multi_tensor_lamb_update_weights", &multi_tensor_lamb_update_weights_cuda,
        "Applies update term for LAMB optimizer");
}
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu deleted 100644 → 0
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

template<typename FROM_T, typename TO_T>
__device__ void convert(const FROM_T vi, TO_T& vo)
{
  vo = static_cast<TO_T>(vi);
}

template<>
__device__ void convert(const float vi, uint8_t& vo)
{
  union S { float as_float; int as_int; };
  S s;
  s.as_float = vi;
  s.as_int = s.as_int & 0xFF800000;
  union T { at::Half as_half; uint8_t as_byte[2]; };
  T t;
  t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
  vo = t.as_byte[1];
}

template<>
__device__ void convert(const uint8_t vi, float& vo)
{
  union T { at::Half as_half; uint8_t as_byte[2]; };
  T t;
  t.as_byte[0] = 0;
  t.as_byte[1] = vi;
  vo = static_cast<float>(t.as_half);
}

template<>
__device__ void convert(const at::Half vi, uint8_t& vo)
{
  union S { float as_float; int as_int; };
  S s;
  s.as_float = static_cast<float>(vi);
  s.as_int = s.as_int & 0xFF800000;
  union T { at::Half as_half; uint8_t as_byte[2]; };
  T t;
  t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
  vo = t.as_byte[1];
}

template<>
__device__ void convert(const uint8_t vi, at::Half& vo)
{
  union T { at::Half as_half; uint8_t as_byte[2]; };
  T t;
  t.as_byte[0] = 0;
  t.as_byte[1] = vi;
  vo = t.as_half;
}

typedef enum{
  MOMENT_MODE_0 = 0, // L2 regularization mode
  MOMENT_MODE_1 = 1  // Decoupled weight decay mode
} adamMode_t;
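My reading of the convert() specializations above (added note, not in the original source): the float/at::Half → uint8_t paths store only the high byte of the fp16 encoding, i.e. the sign bit, the 5 exponent bits, and the top 2 mantissa bits. The mask s.as_int & 0xFF800000 isolates the sign and exponent of the input, so s.as_float is roughly \( \pm 2^{e} \), and adding \( s.\text{as\_float}/8 \approx \pm 2^{e-3} \) before the fp16 cast acts as round-to-nearest for the mantissa bits that the truncation to one byte discards; the uint8_t → float/at::Half paths reconstitute an fp16 with the low byte zeroed. In short, the stored byte is approximately the high byte of \( \mathrm{fp16}\!\left(v + \mathrm{sign}(v)\,2^{e(v)-3}\right) \).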
template
<
typename
T
,
typename
GRAD_T
,
typename
MATH_T
>
struct
DistOptLAMBStage1Functor
{
__device__
__forceinline__
void
operator
()(
int
chunk_size
,
volatile
int
*
noop_gmem
,
TensorListMetadata
<
5
>&
tl
,
const
MATH_T
*
per_tensor_beta1
,
const
MATH_T
*
per_tensor_beta2
,
const
MATH_T
*
per_tensor_beta3
,
const
int
*
per_tensor_bias_correction
,
const
int
*
step
,
const
MATH_T
*
per_tensor_epsilon
,
adamMode_t
mode
,
const
MATH_T
*
per_tensor_decay
,
const
MATH_T
*
global_scale
,
const
MATH_T
*
global_grad_norm
,
const
float
max_grad_norm
)
{
// I'd like this kernel to propagate infs/nans.
if
(
*
noop_gmem
==
1
)
return
;
int
tensor_loc
=
tl
.
block_to_tensor
[
blockIdx
.
x
];
int
tensor_num
=
tl
.
start_tensor_this_launch
+
tensor_loc
;
int
chunk_idx
=
tl
.
block_to_chunk
[
blockIdx
.
x
];
int
n
=
tl
.
sizes
[
tensor_loc
];
float
combined_scale
=
*
global_scale
;
if
(
max_grad_norm
>
0
)
{
combined_scale
=
max_grad_norm
/
(
*
global_grad_norm
/
*
global_scale
+
1e-6
);
combined_scale
=
*
global_scale
/
std
::
min
((
float
)
1.0
,
combined_scale
);
}
MATH_T
beta1
=
per_tensor_beta1
[
tensor_num
];
MATH_T
beta2
=
per_tensor_beta2
[
tensor_num
];
MATH_T
beta3
=
1
-
beta1
;
MATH_T
beta1_correction
,
beta2_correction
;
if
(
per_tensor_bias_correction
[
tensor_num
]
==
1
)
{
beta1_correction
=
1
-
pow
(
beta1
,
*
step
);
beta2_correction
=
1
-
pow
(
beta2
,
*
step
);
}
else
{
beta1_correction
=
(
MATH_T
)
1.0
;
beta2_correction
=
(
MATH_T
)
1.0
;
}
MATH_T
epsilon
=
per_tensor_epsilon
[
tensor_num
];
MATH_T
decay
=
per_tensor_decay
[
tensor_num
];
GRAD_T
*
g
=
(
GRAD_T
*
)
tl
.
addresses
[
0
][
tensor_loc
];
g
+=
chunk_idx
*
chunk_size
;
T
*
p
=
(
T
*
)
tl
.
addresses
[
1
][
tensor_loc
];
p
+=
chunk_idx
*
chunk_size
;
T
*
m
=
(
T
*
)
tl
.
addresses
[
2
][
tensor_loc
];
m
+=
chunk_idx
*
chunk_size
;
T
*
v
=
(
T
*
)
tl
.
addresses
[
3
][
tensor_loc
];
v
+=
chunk_idx
*
chunk_size
;
MATH_T
*
u
=
(
MATH_T
*
)
tl
.
addresses
[
4
][
tensor_loc
];
u
+=
chunk_idx
*
chunk_size
;
n
-=
chunk_idx
*
chunk_size
;
MATH_T
r_g
[
ILP
];
MATH_T
r_p
[
ILP
];
MATH_T
r_m
[
ILP
];
MATH_T
r_v
[
ILP
];
// to make things simple, we put aligned case in a different code path
if
(
n
%
ILP
==
0
&&
chunk_size
%
ILP
==
0
&&
is_aligned
(
g
)
&&
is_aligned
(
p
)
&&
is_aligned
(
m
)
&&
is_aligned
(
v
))
{
GRAD_T
l_g
[
ILP
];
T
l_p
[
ILP
];
T
l_m
[
ILP
];
T
l_v
[
ILP
];
for
(
int
i_start
=
threadIdx
.
x
;
i_start
*
ILP
<
n
&&
i_start
*
ILP
<
chunk_size
;
i_start
+=
blockDim
.
x
)
{
// load
load_store
(
l_g
,
g
,
0
,
i_start
);
if
(
decay
!=
0
)
load_store
(
l_p
,
p
,
0
,
i_start
);
load_store
(
l_m
,
m
,
0
,
i_start
);
load_store
(
l_v
,
v
,
0
,
i_start
);
// unpack
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
r_g
[
ii
]
=
l_g
[
ii
];
if
(
decay
==
0
)
{
r_p
[
ii
]
=
MATH_T
(
0
);
}
else
{
r_p
[
ii
]
=
l_p
[
ii
];
}
r_m
[
ii
]
=
l_m
[
ii
];
r_v
[
ii
]
=
l_v
[
ii
];
}
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
if
(
mode
==
MOMENT_MODE_0
)
{
MATH_T
scaled_grad
=
r_g
[
ii
]
/
combined_scale
;
// L2 on scaled grad
scaled_grad
=
scaled_grad
+
decay
*
r_p
[
ii
];
r_m
[
ii
]
=
r_m
[
ii
]
*
beta1
+
beta3
*
scaled_grad
;
r_v
[
ii
]
=
r_v
[
ii
]
*
beta2
+
(
1
-
beta2
)
*
scaled_grad
*
scaled_grad
;
MATH_T
next_m_unbiased
=
r_m
[
ii
]
/
beta1_correction
;
MATH_T
next_v_unbiased
=
r_v
[
ii
]
/
beta2_correction
;
MATH_T
denom
=
sqrtf
(
next_v_unbiased
)
+
epsilon
;
r_p
[
ii
]
=
next_m_unbiased
/
denom
;
}
else
{
MATH_T
scaled_grad
=
r_g
[
ii
]
/
combined_scale
;
r_m
[
ii
]
=
r_m
[
ii
]
*
beta1
+
beta3
*
scaled_grad
;
r_v
[
ii
]
=
r_v
[
ii
]
*
beta2
+
(
1
-
beta2
)
*
scaled_grad
*
scaled_grad
;
MATH_T
next_m_unbiased
=
r_m
[
ii
]
/
beta1_correction
;
MATH_T
next_v_unbiased
=
r_v
[
ii
]
/
beta2_correction
;
MATH_T
denom
=
sqrtf
(
next_v_unbiased
)
+
epsilon
;
r_p
[
ii
]
=
(
next_m_unbiased
/
denom
)
+
(
decay
*
r_p
[
ii
]);
}
}
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
l_m
[
ii
]
=
r_m
[
ii
];
l_v
[
ii
]
=
r_v
[
ii
];
}
// store
load_store
(
u
,
r_p
,
i_start
,
0
);
load_store
(
m
,
l_m
,
i_start
,
0
);
load_store
(
v
,
l_v
,
i_start
,
0
);
}
}
else
{
// see note in multi_tensor_scale_kernel.cu
for
(
int
i_start
=
0
;
i_start
<
n
&&
i_start
<
chunk_size
;
i_start
+=
blockDim
.
x
*
ILP
)
{
MATH_T
r_g
[
ILP
];
MATH_T
r_p
[
ILP
];
MATH_T
r_m
[
ILP
];
MATH_T
r_v
[
ILP
];
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
if
(
i
<
n
&&
i
<
chunk_size
)
{
r_g
[
ii
]
=
g
[
i
];
// special ?optimization? for lamb stage 1
if
(
decay
==
0
)
{
r_p
[
ii
]
=
MATH_T
(
0
);
}
else
{
r_p
[
ii
]
=
p
[
i
];
}
r_m
[
ii
]
=
m
[
i
];
r_v
[
ii
]
=
v
[
i
];
}
else
{
r_g
[
ii
]
=
MATH_T
(
0
);
r_p
[
ii
]
=
MATH_T
(
0
);
r_m
[
ii
]
=
MATH_T
(
0
);
r_v
[
ii
]
=
MATH_T
(
0
);
}
}
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
if
(
mode
==
MOMENT_MODE_0
)
{
MATH_T
scaled_grad
=
r_g
[
ii
]
/
combined_scale
;
// L2 on scaled grad
scaled_grad
=
scaled_grad
+
decay
*
r_p
[
ii
];
r_m
[
ii
]
=
r_m
[
ii
]
*
beta1
+
beta3
*
scaled_grad
;
r_v
[
ii
]
=
r_v
[
ii
]
*
beta2
+
(
1
-
beta2
)
*
scaled_grad
*
scaled_grad
;
MATH_T
next_m_unbiased
=
r_m
[
ii
]
/
beta1_correction
;
MATH_T
next_v_unbiased
=
r_v
[
ii
]
/
beta2_correction
;
MATH_T
denom
=
sqrtf
(
next_v_unbiased
)
+
epsilon
;
r_p
[
ii
]
=
next_m_unbiased
/
denom
;
}
else
{
MATH_T
scaled_grad
=
r_g
[
ii
]
/
combined_scale
;
r_m
[
ii
]
=
r_m
[
ii
]
*
beta1
+
beta3
*
scaled_grad
;
r_v
[
ii
]
=
r_v
[
ii
]
*
beta2
+
(
1
-
beta2
)
*
scaled_grad
*
scaled_grad
;
MATH_T
next_m_unbiased
=
r_m
[
ii
]
/
beta1_correction
;
MATH_T
next_v_unbiased
=
r_v
[
ii
]
/
beta2_correction
;
MATH_T
denom
=
sqrtf
(
next_v_unbiased
)
+
epsilon
;
r_p
[
ii
]
=
(
next_m_unbiased
/
denom
)
+
(
decay
*
r_p
[
ii
]);
}
}
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
if
(
i
<
n
&&
i
<
chunk_size
)
{
u
[
i
]
=
r_p
[
ii
];
m
[
i
]
=
r_m
[
ii
];
v
[
i
]
=
r_v
[
ii
];
}
}
}
}
}
};
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template
<
typename
T
,
typename
GRAD_T
,
typename
MATH_T
>
struct
DistOptLAMBStage2Functor
{
__device__
__forceinline__
void
operator
()(
int
chunk_size
,
volatile
int
*
noop_gmem
,
TensorListMetadata
<
3
>&
tl
,
const
MATH_T
*
per_tensor_param_norm
,
const
MATH_T
*
per_tensor_update_norm
,
const
long
*
update_norm_offset
,
const
MATH_T
*
learning_rate
,
const
MATH_T
*
per_tensor_decay
,
const
MATH_T
*
global_grad_norm
,
bool
use_nvlamb
)
{
// I'd like this kernel to propagate infs/nans.
if
(
*
noop_gmem
==
1
)
return
;
int
tensor_loc
=
tl
.
block_to_tensor
[
blockIdx
.
x
];
int
tensor_num
=
tl
.
start_tensor_this_launch
+
tensor_loc
;
int
chunk_idx
=
tl
.
block_to_chunk
[
blockIdx
.
x
];
int
n
=
tl
.
sizes
[
tensor_loc
];
MATH_T
decay
=
per_tensor_decay
[
tensor_num
];
MATH_T
ratio
=
*
learning_rate
;
// nvlamb: apply adaptive learning rate to all parameters
// otherwise, only apply to those with non-zero weight decay
if
(
use_nvlamb
||
(
decay
!=
(
MATH_T
)
0.0
))
{
MATH_T
param_norm
=
per_tensor_param_norm
[
tensor_num
];
MATH_T
update_norm
=
per_tensor_update_norm
[
update_norm_offset
[
tensor_num
]];
ratio
=
(
update_norm
!=
0.0
&&
param_norm
!=
0.0
)
?
(
*
learning_rate
)
*
(
param_norm
/
update_norm
)
:
(
*
learning_rate
);
}
MATH_T
*
update
=
(
MATH_T
*
)
tl
.
addresses
[
0
][
tensor_loc
];
update
+=
chunk_idx
*
chunk_size
;
T
*
p
=
(
T
*
)
tl
.
addresses
[
1
][
tensor_loc
];
p
+=
chunk_idx
*
chunk_size
;
GRAD_T
*
p_copy
=
(
GRAD_T
*
)
tl
.
addresses
[
2
][
tensor_loc
];
p_copy
+=
chunk_idx
*
chunk_size
;
n
-=
chunk_idx
*
chunk_size
;
// to make things simple, we put aligned case in a different code path
if
(
n
%
ILP
==
0
&&
chunk_size
%
ILP
==
0
&&
is_aligned
(
p
)
&&
is_aligned
(
update
))
{
T
r_p
[
ILP
];
MATH_T
r_update
[
ILP
];
GRAD_T
r_p_copy
[
ILP
];
for
(
int
i_start
=
threadIdx
.
x
;
i_start
*
ILP
<
n
&&
i_start
*
ILP
<
chunk_size
;
i_start
+=
blockDim
.
x
)
{
// load
load_store
(
r_p
,
p
,
0
,
i_start
);
load_store
(
r_update
,
update
,
0
,
i_start
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
r_p
[
ii
]
=
static_cast
<
MATH_T
>
(
r_p
[
ii
])
-
(
ratio
*
r_update
[
ii
]);
convert
(
r_p
[
ii
],
r_p_copy
[
ii
]);
}
load_store
(
p
,
r_p
,
i_start
,
0
);
load_store
(
p_copy
,
r_p_copy
,
i_start
,
0
);
}
}
else
{
for
(
int
i_start
=
0
;
i_start
<
n
&&
i_start
<
chunk_size
;
i_start
+=
blockDim
.
x
*
ILP
)
{
MATH_T
r_p
[
ILP
];
MATH_T
r_update
[
ILP
];
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
if
(
i
<
n
&&
i
<
chunk_size
)
{
r_p
[
ii
]
=
p
[
i
];
r_update
[
ii
]
=
update
[
i
];
}
}
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
r_p
[
ii
]
=
r_p
[
ii
]
-
(
ratio
*
r_update
[
ii
]);
}
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
if
(
i
<
n
&&
i
<
chunk_size
)
{
p
[
i
]
=
r_p
[
ii
];
convert
(
r_p
[
ii
],
p_copy
[
i
]);
}
}
}
}
}
};
void
multi_tensor_lamb_compute_update_term_cuda
(
int
chunk_size
,
at
::
Tensor
noop_flag
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
at
::
Tensor
per_tensor_beta1
,
at
::
Tensor
per_tensor_beta2
,
at
::
Tensor
per_tensor_beta3
,
at
::
Tensor
per_tensor_bias_correction
,
at
::
Tensor
step
,
at
::
Tensor
per_tensor_epsilon
,
const
int
mode
,
at
::
Tensor
per_tensor_decay
,
at
::
Tensor
global_scale
,
at
::
Tensor
global_grad_norm
,
const
float
max_grad_norm
)
{
using
namespace
at
;
DISPATCH_FLOAT_AND_HALF
(
tensor_lists
[
1
][
0
].
scalar_type
(),
0
,
"lamb_stage_1"
,
DISPATCH_FLOAT_AND_HALF
(
tensor_lists
[
0
][
0
].
scalar_type
(),
1
,
"lamb_stage_1"
,
DISPATCH_FLOAT_AND_HALF
(
tensor_lists
[
4
][
0
].
scalar_type
(),
2
,
"lamb_stage_1"
,
multi_tensor_apply
<
5
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
DistOptLAMBStage1Functor
<
scalar_t_0
,
scalar_t_1
,
scalar_t_2
>
(),
per_tensor_beta1
.
DATA_PTR
<
scalar_t_2
>
(),
per_tensor_beta2
.
DATA_PTR
<
scalar_t_2
>
(),
per_tensor_beta3
.
DATA_PTR
<
scalar_t_2
>
(),
per_tensor_bias_correction
.
DATA_PTR
<
int
>
(),
step
.
DATA_PTR
<
int
>
(),
per_tensor_epsilon
.
DATA_PTR
<
scalar_t_2
>
(),
(
adamMode_t
)
mode
,
per_tensor_decay
.
DATA_PTR
<
scalar_t_2
>
(),
global_scale
.
DATA_PTR
<
scalar_t_2
>
(),
global_grad_norm
.
DATA_PTR
<
scalar_t_2
>
(),
max_grad_norm
);
)))
AT_CUDA_CHECK
(
cudaGetLastError
());
}
void
multi_tensor_lamb_update_weights_cuda
(
int
chunk_size
,
at
::
Tensor
noop_flag
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
at
::
Tensor
per_tensor_param_norm
,
at
::
Tensor
per_tensor_update_norm
,
at
::
Tensor
update_norm_offset
,
at
::
Tensor
learning_rate
,
at
::
Tensor
per_tensor_decay
,
at
::
Tensor
global_grad_norm
,
bool
use_nvlamb
)
{
using
namespace
at
;
DISPATCH_FLOAT_AND_HALF
(
tensor_lists
[
1
][
0
].
scalar_type
(),
0
,
"lamb_stage_2"
,
DISPATCH_FLOAT_HALF_AND_BYTE
(
tensor_lists
[
2
][
0
].
scalar_type
(),
1
,
"lamb_stage_2"
,
DISPATCH_FLOAT_AND_HALF
(
tensor_lists
[
0
][
0
].
scalar_type
(),
2
,
"lamb_stage_2"
,
multi_tensor_apply
<
3
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
DistOptLAMBStage2Functor
<
scalar_t_0
,
scalar_t_1
,
scalar_t_2
>
(),
per_tensor_param_norm
.
DATA_PTR
<
scalar_t_2
>
(),
per_tensor_update_norm
.
DATA_PTR
<
scalar_t_2
>
(),
update_norm_offset
.
DATA_PTR
<
long
>
(),
learning_rate
.
DATA_PTR
<
scalar_t_2
>
(),
per_tensor_decay
.
DATA_PTR
<
scalar_t_2
>
(),
global_grad_norm
.
DATA_PTR
<
scalar_t_2
>
(),
use_nvlamb
);
)))
AT_CUDA_CHECK
(
cudaGetLastError
());
}
apex/contrib/csrc/peer_memory/peer_memory.cpp deleted 100644 → 0
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "peer_memory_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw");
  m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw");
  m.def("zero", &apex::contrib::peer_memory::zero, "zero");
  m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address");
  m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers");
  m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half");
  m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float");
  m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int");
  m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d");
}
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu deleted 100644 → 0
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <cassert>
#include <cuda_runtime_api.h>
#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_cooperative_groups.h>
#include "rccl/rccl.h"
#else
#include <cooperative_groups.h>
#include "nccl.h"
#endif

namespace cg = cooperative_groups;

#define CUDACHECK(cmd) do {                                 \
    cudaError_t err = cmd;                                  \
    if( err != cudaSuccess ) {                              \
        char hostname[1024];                                \
        gethostname(hostname, 1024);                        \
        printf("%s: CUDA failure %s:%d '%s'\n",             \
             hostname,                                      \
             __FILE__,__LINE__,cudaGetErrorString(err));    \
    }                                                       \
} while(0)

// C++17 removes 'register' storage keyword
#if __cplusplus < 201703L
#define REGISTER register
#else
#define REGISTER
#endif

namespace {

/* Basic deleter function for from_blob function.
void deleter(void* ptr)
{
    printf("deleter(ptr=%p)\n",ptr);
    cudaFree(ptr);
}
*/

template<class T>
at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options, bool channels_last)
{
    size_t size = 1;
    std::vector<int64_t> strides(shape.size());
    if (channels_last) {
        assert(shape.size() == 4);
        strides[0] = shape[1]*shape[2]*shape[3];
        strides[1] = 1;
        strides[2] = shape[1]*shape[3];
        strides[3] = shape[1];
    } else {
        int idx = strides.size();
        for (auto it = shape.rbegin(); it != shape.rend(); ++it)
        {
            strides[--idx] = size;
            size *= *it;
        }
    }
    size *= sizeof(T);
    // TODO: Implement dynamic reuse of pooled peer memory.
    // We provide no deleter function because all peer memory allocations are static in this implementation.
    return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options);
}
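blob_view is essentially a thin wrapper over torch::from_blob with manually computed strides and no deleter. A minimal illustration of the same call (illustrative only, not part of the original file):

// Wrap an existing CUDA allocation as a 1-D float tensor without copying.
// Assumption: `raw` points to at least n*sizeof(float) bytes of device memory
// that outlives the returned tensor, since no deleter is registered.
#include <torch/torch.h>

at::Tensor wrap_raw_float(void* raw, int64_t n) {
  auto options = torch::dtype(torch::kFloat32).device(torch::kCUDA);
  return torch::from_blob(raw, {n}, options);
}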
void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W)
{
    if (t.dim() == 3) {
        N = 1;
        if (explicit_nhwc) {
            C = t.size(2);  H = t.size(0);  W = t.size(1);
        } else {
            C = t.size(0);  H = t.size(1);  W = t.size(2);
        }
    } else if (t.dim() == 4) {
        if (explicit_nhwc) {
            N = t.size(0);  C = t.size(3);  H = t.size(1);  W = t.size(2);
        } else {
            N = t.size(0);  C = t.size(1);  H = t.size(2);  W = t.size(3);
        }
    } else {
        printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n", __FILE__, __LINE__, t.dim());
        assert(t.dim() == 3 || t.dim() == 4);
    }
}

void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride_C, int& stride_H, int& stride_W)
{
    if (t.dim() == 3) {
        if (explicit_nhwc) {
            stride_C = t.stride(2);  stride_H = t.stride(0);  stride_W = t.stride(1);
        } else {
            stride_C = t.stride(0);  stride_H = t.stride(1);  stride_W = t.stride(2);
        }
        stride_N = t.size(0)*t.size(1)*t.size(2);
    } else if (t.dim() == 4) {
        if (explicit_nhwc) {
            stride_N = t.stride(0);  stride_C = t.stride(3);  stride_H = t.stride(1);  stride_W = t.stride(2);
        } else {
            stride_N = t.stride(0);  stride_C = t.stride(1);  stride_H = t.stride(2);  stride_W = t.stride(3);
        }
    } else {
        printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n", __FILE__, __LINE__, t.dim());
        assert(t.dim() == 3 || t.dim() == 4);
    }
}
template<class T>
__device__ void __zero(T* dst)
{
    *dst = T(0);
}

__device__ void __zero(int4* dst)
{
    int4 v;
    v.x = v.y = v.z = v.w = 0;
    *dst = v;
}

template<class T, bool is_HWC, bool zero>
__device__ void strided_copy_kernel(
        T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
        const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,
        const int NC, const int NH, const int NW)
{
    size_t tot_num_threads = gridDim.x * blockDim.x;
    size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    const size_t count = NC*NH*NW;
    for (size_t i = thread_id; i < count; i += tot_num_threads)
    {
        size_t c, h, w;
        if (is_HWC) {
            w = i / NC;
            c = i - w*NC;
            h = w / NW;
            w = w - h*NW;
        } else {
            h = i / NW;
            w = i - h*NW;
            c = h / NH;
            h = h - c*NH;
        }
        size_t dst_off = c*dst_stride_C + h*dst_stride_H + w*dst_stride_W;
        if (zero) {
            __zero(dst + dst_off);
        } else {
            size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W;
            dst[dst_off] = src[src_off];
        }
    }
}
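A small host-side sketch of the flat-index decomposition that strided_copy_kernel uses (illustrative, not part of the original file); it reproduces the non-HWC branch, where the grid-stride index i enumerates W fastest, then H, then C:

// Host-side check of the (c, h, w) decomposition from a flat index i,
// matching the non-HWC branch of strided_copy_kernel above.
#include <cassert>
#include <cstddef>

void check_chw_decomposition(int NC, int NH, int NW) {
  for (size_t i = 0; i < (size_t)NC * NH * NW; ++i) {
    size_t h = i / NW;
    size_t w = i - h * NW;
    size_t c = h / NH;
    h = h - c * NH;
    // The flat index enumerates w fastest, then h, then c.
    assert(i == (c * NH + h) * NW + w);
  }
}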
template
<
bool
top_zero
,
bool
btm_zero
>
__device__
void
checked_signal
(
volatile
int
*
signal1_flag
,
volatile
int
*
signal2_flag
,
const
int
v1
,
const
int
v2
,
const
int
v3
,
const
int
v4
)
{
cg
::
this_grid
().
sync
();
bool
is_main_thread
=
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
?
true
:
false
;
if
(
is_main_thread
)
{
// flush all writes to global memory
__threadfence_system
();
// wait for top or bottom neighbor to clear signal
REGISTER
int
r1
,
r2
,
r3
,
r4
;
if
(
!
(
top_zero
||
btm_zero
))
{
bool
top_zeroed
=
false
,
top_done
=
false
;
bool
btm_zeroed
=
false
,
btm_done
=
false
;
do
{
do
{
if
(
!
top_zeroed
)
{
#ifdef __HIP_PLATFORM_HCC__
r1
=
__builtin_nontemporal_load
(
signal1_flag
);
r2
=
__builtin_nontemporal_load
(
signal1_flag
+
1
);
r3
=
__builtin_nontemporal_load
(
signal1_flag
+
2
);
r4
=
__builtin_nontemporal_load
(
signal1_flag
+
3
);
#else
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
signal1_flag
)
:
"memory"
);
#endif
if
(
r1
!=
v1
||
r2
!=
v2
||
r3
!=
v3
||
r4
!=
v4
)
top_zeroed
=
true
;
}
if
(
!
btm_zeroed
)
{
#ifdef __HIP_PLATFORM_HCC__
r1
=
__builtin_nontemporal_load
(
signal2_flag
);
r2
=
__builtin_nontemporal_load
(
signal2_flag
+
1
);
r3
=
__builtin_nontemporal_load
(
signal2_flag
+
2
);
r4
=
__builtin_nontemporal_load
(
signal2_flag
+
3
);
#else
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
signal2_flag
)
:
"memory"
);
#endif
if
(
r1
!=
v1
||
r2
!=
v2
||
r3
!=
v3
||
r4
!=
v4
)
btm_zeroed
=
true
;
}
}
while
((
top_zeroed
==
top_done
)
&&
(
btm_zeroed
==
btm_done
));
if
(
!
top_done
&&
top_zeroed
)
{
// signal to top neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store
(
v1
,
signal1_flag
);
__builtin_nontemporal_store
(
v2
,
signal1_flag
+
1
);
__builtin_nontemporal_store
(
v3
,
signal1_flag
+
2
);
__builtin_nontemporal_store
(
v4
,
signal1_flag
+
3
);
#else
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal1_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
#endif
top_done
=
true
;
}
if
(
!
btm_done
&&
btm_zeroed
)
{
// signal to bottom neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store
(
v1
,
signal2_flag
);
__builtin_nontemporal_store
(
v2
,
signal2_flag
+
1
);
__builtin_nontemporal_store
(
v3
,
signal2_flag
+
2
);
__builtin_nontemporal_store
(
v4
,
signal2_flag
+
3
);
#else
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal2_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
#endif
btm_done
=
true
;
}
}
while
(
!
top_done
||
!
btm_done
);
}
else
if
(
top_zero
)
{
bool
btm_zeroed
=
false
,
btm_done
=
false
;
do
{
do
{
if
(
!
btm_zeroed
)
{
#ifdef __HIP_PLATFORM_HCC__
r1
=
__builtin_nontemporal_load
(
signal2_flag
);
r2
=
__builtin_nontemporal_load
(
signal2_flag
+
1
);
r3
=
__builtin_nontemporal_load
(
signal2_flag
+
2
);
r4
=
__builtin_nontemporal_load
(
signal2_flag
+
3
);
#else
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
signal2_flag
)
:
"memory"
);
#endif
if
(
r1
!=
v1
||
r2
!=
v2
||
r3
!=
v3
||
r4
!=
v4
)
btm_zeroed
=
true
;
}
}
while
(
btm_zeroed
==
btm_done
);
if
(
!
btm_done
&&
btm_zeroed
)
{
// signal to bottom neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store
(
v1
,
signal2_flag
);
__builtin_nontemporal_store
(
v2
,
signal2_flag
+
1
);
__builtin_nontemporal_store
(
v3
,
signal2_flag
+
2
);
__builtin_nontemporal_store
(
v4
,
signal2_flag
+
3
);
#else
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal2_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
#endif
btm_done
=
true
;
}
}
while
(
!
btm_done
);
}
else
if
(
btm_zero
)
{
bool
top_zeroed
=
false
,
top_done
=
false
;
do
{
do
{
if
(
!
top_zeroed
)
{
#ifdef __HIP_PLATFORM_HCC__
r1
=
__builtin_nontemporal_load
(
signal1_flag
);
r2
=
__builtin_nontemporal_load
(
signal1_flag
+
1
);
r3
=
__builtin_nontemporal_load
(
signal1_flag
+
2
);
r4
=
__builtin_nontemporal_load
(
signal1_flag
+
3
);
#else
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
signal1_flag
)
:
"memory"
);
#endif
if
(
r1
!=
v1
||
r2
!=
v2
||
r3
!=
v3
||
r4
!=
v4
)
top_zeroed
=
true
;
}
}
while
(
top_zeroed
==
top_done
);
if
(
!
top_done
&&
top_zeroed
)
{
// signal to top neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store
(
v1
,
signal1_flag
);
__builtin_nontemporal_store
(
v2
,
signal1_flag
+
1
);
__builtin_nontemporal_store
(
v3
,
signal1_flag
+
2
);
__builtin_nontemporal_store
(
v4
,
signal1_flag
+
3
);
#else
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
signal1_flag
),
"r"
(
v1
),
"r"
(
v2
),
"r"
(
v3
),
"r"
(
v4
)
:
"memory"
);
#endif
top_done
=
true
;
}
}
while
(
!
top_done
);
}
}
}
__device__
void
wait_for
(
volatile
int
*
wait_flag
,
const
int
v1
,
const
int
v2
,
const
int
v3
,
const
int
v4
)
{
bool
is_main_thread
=
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
?
true
:
false
;
if
(
is_main_thread
)
{
REGISTER
int
r1
,
r2
,
r3
,
r4
;
// wait for senders to signal their output is read
do
{
#ifdef __HIP_PLATFORM_HCC__
r1
=
__builtin_nontemporal_load
(
wait_flag
);
r2
=
__builtin_nontemporal_load
(
wait_flag
+
1
);
r3
=
__builtin_nontemporal_load
(
wait_flag
+
2
);
r4
=
__builtin_nontemporal_load
(
wait_flag
+
3
);
#else
asm
volatile
(
"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
:
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
)
:
"l"
(
wait_flag
)
:
"memory"
);
#endif
}
while
(
r1
!=
v1
||
r2
!=
v2
||
r3
!=
v3
||
r4
!=
v4
);
}
cg
::
this_grid
().
sync
();
// all threads wait for main
}
__device__
void
clear_flag
(
volatile
int
*
wait_flag
)
{
cg
::
this_grid
().
sync
();
// wait for all threads in kernel to finish
bool
is_main_thread
=
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
?
true
:
false
;
if
(
is_main_thread
)
{
REGISTER
int
r1
,
r2
,
r3
,
r4
;
r1
=
0
;
r2
=
0
;
r3
=
0
;
r4
=
0
;
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store
(
r1
,
wait_flag
);
__builtin_nontemporal_store
(
r2
,
wait_flag
+
1
);
__builtin_nontemporal_store
(
r3
,
wait_flag
+
2
);
__builtin_nontemporal_store
(
r4
,
wait_flag
+
3
);
#else
asm
volatile
(
"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
::
"l"
(
wait_flag
),
"r"
(
r1
),
"r"
(
r2
),
"r"
(
r3
),
"r"
(
r4
)
:
"memory"
);
#endif
}
}
template
<
class
T
,
bool
is_HWC
,
bool
top_zero
,
bool
btm_zero
>
#if __CUDA_ARCH__ == 700 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 900
__launch_bounds__
(
128
,
16
)
#endif
__global__
void
push_pull_halos_1d_kernel
(
// top halo,
const
T
*
toh
,
int
toh_stride_C
,
int
toh_stride_H
,
int
toh_stride_W
,
// top output halo
T
*
tox
,
int
tox_stride_C
,
int
tox_stride_H
,
int
tox_stride_W
,
// top output tx buffer
T
*
tix
,
int
tix_stride_C
,
int
tix_stride_H
,
int
tix_stride_W
,
// top input tx buffer
T
*
tih
,
int
tih_stride_C
,
int
tih_stride_H
,
int
tih_stride_W
,
// top input halo
// btm halo
const
T
*
boh
,
int
boh_stride_C
,
int
boh_stride_H
,
int
boh_stride_W
,
// btm output halo
T
*
box
,
int
box_stride_C
,
int
box_stride_H
,
int
box_stride_W
,
// btm output tx buffer
T
*
bix
,
int
bix_stride_C
,
int
bix_stride_H
,
int
bix_stride_W
,
// btm input tx buffer
T
*
bih
,
int
bih_stride_C
,
int
bih_stride_H
,
int
bih_stride_W
,
// btm input halo
// dimensions
int
NC
,
int
NH
,
int
NW
,
// signals
int
*
signal1_flag
,
int
*
signal2_flag
,
int
*
wait1_flag
,
int
*
wait2_flag
)
{
// push top output halo to transfer buffer
if
(
!
top_zero
)
strided_copy_kernel
<
T
,
is_HWC
,
false
>
(
tox
,
tox_stride_C
,
tox_stride_H
,
tox_stride_W
,
toh
,
toh_stride_C
,
toh_stride_H
,
toh_stride_W
,
NC
,
NH
,
NW
);
// push btm output halo to transfer buffer
if
(
!
btm_zero
)
strided_copy_kernel
<
T
,
is_HWC
,
false
>
(
box
,
box_stride_C
,
box_stride_H
,
box_stride_W
,
boh
,
boh_stride_C
,
boh_stride_H
,
boh_stride_W
,
NC
,
NH
,
NW
);
// signal to top and btm neigbhbors that output halos are ready to be read
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
if
(
!
(
top_zero
||
btm_zero
))
{
checked_signal
<
false
,
false
>
(
signal1_flag
,
signal2_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
}
else
if
(
top_zero
)
{
checked_signal
<
true
,
false
>
(
signal1_flag
,
signal2_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
}
else
if
(
btm_zero
)
{
checked_signal
<
false
,
true
>
(
signal1_flag
,
signal2_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
}
// pull top halo from transfer buffer in peer memory to input
if
(
top_zero
)
{
strided_copy_kernel
<
T
,
is_HWC
,
true
>
(
tih
,
tih_stride_C
,
tih_stride_H
,
tih_stride_W
,
tix
,
tix_stride_C
,
tix_stride_H
,
tix_stride_W
,
NC
,
NH
,
NW
);
}
else
{
wait_for
(
wait1_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
strided_copy_kernel
<
T
,
is_HWC
,
false
>
(
tih
,
tih_stride_C
,
tih_stride_H
,
tih_stride_W
,
tix
,
tix_stride_C
,
tix_stride_H
,
tix_stride_W
,
NC
,
NH
,
NW
);
clear_flag
(
wait1_flag
);
}
// pull btm halo from transfer buffer in peer memory to input
if
(
btm_zero
)
{
strided_copy_kernel
<
T
,
is_HWC
,
true
>
(
bih
,
bih_stride_C
,
bih_stride_H
,
bih_stride_W
,
bix
,
bix_stride_C
,
bix_stride_H
,
bix_stride_W
,
NC
,
NH
,
NW
);
}
else
{
wait_for
(
wait2_flag
,
-
987751720
,
840868300
,
-
225529332
,
281513358
);
strided_copy_kernel
<
T
,
is_HWC
,
false
>
(
bih
,
bih_stride_C
,
bih_stride_H
,
bih_stride_W
,
bix
,
bix_stride_C
,
bix_stride_H
,
bix_stride_W
,
NC
,
NH
,
NW
);
clear_flag
(
wait2_flag
);
}
}
__global__
void
delay_kernel
(
int
delay_nanoseconds
,
int
*
counter
)
{
if
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
{
// waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
int
new_counter
=
0
;
double
elapsed
=
0
;
clock_t
start
=
clock
();
do
{
clock_t
now
=
clock
();
elapsed
=
(
double
)(
now
-
start
)
*
1e9
/
CLOCKS_PER_SEC
;
++
new_counter
;
}
while
(
elapsed
<
(
double
)
delay_nanoseconds
);
*
counter
=
new_counter
;
}
}
}
namespace
apex
{
namespace
contrib
{
namespace
peer_memory
{
int64_t
allocate_raw
(
int64_t
size
)
{
float
*
ptr
=
0L
;
cudaMalloc
(
&
ptr
,
size
);
cudaMemset
(
ptr
,
0
,
size
);
return
(
int64_t
)
ptr
;
}
void
free_raw
(
int64_t
raw
)
{
cudaFree
((
void
*
)
raw
);
}
void
zero
(
int64_t
raw
,
int64_t
size
)
{
cudaMemset
((
void
*
)
raw
,
0
,
size
);
}
at
::
Tensor
get_raw_ipc_address
(
int64_t
raw
)
{
cudaIpcMemHandle_t
mem_handle
;
CUDACHECK
(
cudaIpcGetMemHandle
(
&
mem_handle
,
(
void
*
)
raw
)
);
const
int
n
=
sizeof
(
cudaIpcMemHandle_t
);
auto
address_tensor
=
torch
::
empty
({
n
},
torch
::
dtype
(
torch
::
kUInt8
));
auto
address_tensor_p
=
address_tensor
.
data_ptr
<
uint8_t
>
();
memcpy
(
address_tensor_p
,
(
uint8_t
*
)
&
mem_handle
,
n
);
return
address_tensor
;
}
std
::
vector
<
int64_t
>
get_raw_peers
(
at
::
Tensor
ipc_addresses
,
int
peer_rank
,
int64_t
raw
)
{
int
peer_group_size
=
ipc_addresses
.
size
(
0
);
std
::
vector
<
int64_t
>
results
(
peer_group_size
);
for
(
int
i
=
0
;
i
<
peer_group_size
;
++
i
)
{
if
(
i
!=
peer_rank
)
{
cudaIpcMemHandle_t
mem_handle
;
memcpy
(
&
mem_handle
,
ipc_addresses
.
index
({
i
}).
data_ptr
<
uint8_t
>
(),
sizeof
(
cudaIpcMemHandle_t
));
void
*
p
=
0L
;
CUDACHECK
(
cudaIpcOpenMemHandle
((
void
**
)
&
p
,
mem_handle
,
cudaIpcMemLazyEnablePeerAccess
)
);
results
[
i
]
=
(
int64_t
)
p
;
}
else
{
results
[
i
]
=
(
int64_t
)
raw
;
}
}
return
results
;
}
at
::
Tensor
blob_view_half
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
,
bool
channels_last
)
{
return
blob_view
<
at
::
Half
>
((
at
::
Half
*
)
raw
,
shape
,
torch
::
dtype
(
torch
::
kFloat16
).
device
(
torch
::
kCUDA
),
channels_last
);
}
at
::
Tensor
blob_view_float
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
,
bool
channels_last
)
{
return
blob_view
<
float
>
((
float
*
)
raw
,
shape
,
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
),
channels_last
);
}
at
::
Tensor
blob_view_int
(
int64_t
raw
,
std
::
vector
<
int64_t
>
shape
,
bool
channels_last
)
{
return
blob_view
<
int
>
((
int
*
)
raw
,
shape
,
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
),
channels_last
);
}
void
push_pull_halos_1d
(
bool
diagnostics
,
bool
explicit_nhwc
,
int
numSM
,
// number of SMs to use
bool
top_zero
,
// true if top halo should be zeroed
at
::
Tensor
top_out_halo
,
// top output halo in sender device memory
at
::
Tensor
top_out_tx
,
// top output transfer buffer in sender peer pool memory
at
::
Tensor
top_inp_tx
,
// top input transfer buffer in top neighbor peer pool memory
at
::
Tensor
top_inp_halo
,
// top input halo in receiver device memory
bool
btm_zero
,
// true if btm halo should be zeroed
at
::
Tensor
btm_out_halo
,
// btm output halo in sender device memory
at
::
Tensor
btm_out_tx
,
// btm output transfer buffer in sender peer pool memory
at
::
Tensor
btm_inp_tx
,
// btm input transfer buffer in btm neighbor peer pool memory
at
::
Tensor
btm_inp_halo
,
// btm input halo in receiver device memory
at
::
Tensor
top_signal
,
// top input signal in receiver device memory
at
::
Tensor
btm_signal
,
// btm input signal in receiver device memory
at
::
Tensor
waits
// top and btm signals for this rank
)
{
// basic checks of inputs
TORCH_CHECK
(
top_out_halo
.
is_cuda
());
TORCH_CHECK
(
top_out_tx
.
is_cuda
());
TORCH_CHECK
(
top_inp_tx
.
is_cuda
());
TORCH_CHECK
(
top_inp_halo
.
is_cuda
());
TORCH_CHECK
(
btm_out_halo
.
is_cuda
());
TORCH_CHECK
(
btm_out_tx
.
is_cuda
());
TORCH_CHECK
(
btm_inp_tx
.
is_cuda
());
TORCH_CHECK
(
btm_inp_halo
.
is_cuda
());
TORCH_CHECK
(
top_signal
.
is_cuda
());
TORCH_CHECK
(
btm_signal
.
is_cuda
());
TORCH_CHECK
(
waits
.
is_cuda
());
TORCH_CHECK
(
!
(
top_zero
&&
btm_zero
));
// shapes and strides
int
toh_N
,
toh_C
,
toh_H
,
toh_W
;
tensor_shape
(
top_out_halo
,
explicit_nhwc
,
toh_N
,
toh_C
,
toh_H
,
toh_W
);
int
tox_N
,
tox_C
,
tox_H
,
tox_W
;
tensor_shape
(
top_out_tx
,
explicit_nhwc
,
tox_N
,
tox_C
,
tox_H
,
tox_W
);
int
tix_N
,
tix_C
,
tix_H
,
tix_W
;
tensor_shape
(
top_inp_tx
,
explicit_nhwc
,
tix_N
,
tix_C
,
tix_H
,
tix_W
);
int
tih_N
,
tih_C
,
tih_H
,
tih_W
;
tensor_shape
(
top_inp_halo
,
explicit_nhwc
,
tih_N
,
tih_C
,
tih_H
,
tih_W
);
TORCH_CHECK
(
(
toh_N
==
tox_N
&&
tox_N
==
tix_N
&&
tix_N
==
tih_N
)
&&
(
toh_C
==
tox_C
&&
tox_C
==
tix_C
&&
tix_C
==
tih_C
)
&&
(
toh_H
==
tox_H
&&
tox_H
==
tix_H
&&
tix_H
==
tih_H
)
&&
(
toh_W
==
tox_W
&&
tox_W
==
tix_W
&&
tix_W
==
tih_W
));
int
boh_N
,
boh_C
,
boh_H
,
boh_W
;
tensor_shape
(
btm_out_halo
,
explicit_nhwc
,
boh_N
,
boh_C
,
boh_H
,
boh_W
);
int
box_N
,
box_C
,
box_H
,
box_W
;
tensor_shape
(
btm_out_tx
,
explicit_nhwc
,
box_N
,
box_C
,
box_H
,
box_W
);
int
bix_N
,
bix_C
,
bix_H
,
bix_W
;
tensor_shape
(
btm_inp_tx
,
explicit_nhwc
,
bix_N
,
bix_C
,
bix_H
,
bix_W
);
int
bih_N
,
bih_C
,
bih_H
,
bih_W
;
tensor_shape
(
btm_inp_halo
,
explicit_nhwc
,
bih_N
,
bih_C
,
bih_H
,
bih_W
);
TORCH_CHECK
(
(
boh_N
==
box_N
&&
box_N
==
bix_N
&&
bix_N
==
bih_N
)
&&
(
boh_C
==
box_C
&&
box_C
==
bix_C
&&
bix_C
==
bih_C
)
&&
(
boh_H
==
box_H
&&
box_H
==
bix_H
&&
bix_H
==
bih_H
)
&&
(
boh_W
==
box_W
&&
box_W
==
bix_W
&&
bix_W
==
bih_W
));
TORCH_CHECK
(
(
toh_N
==
boh_N
)
&&
(
toh_C
==
boh_C
)
&&
(
toh_H
==
boh_H
)
&&
(
toh_W
==
boh_W
));
int
NC
=
toh_C
,
NH
=
toh_H
,
NW
=
toh_W
;
if
(
diagnostics
)
printf
(
"NC=%d, NH=%d, NW=%d
\n
"
,
NC
,
NH
,
NW
);
int
toh_stride_N
,
toh_stride_C
,
toh_stride_H
,
toh_stride_W
;
tensor_strides
(
top_out_halo
,
explicit_nhwc
,
toh_stride_N
,
toh_stride_C
,
toh_stride_H
,
toh_stride_W
);
int
tox_stride_N
,
tox_stride_C
,
tox_stride_H
,
tox_stride_W
;
tensor_strides
(
top_out_tx
,
explicit_nhwc
,
tox_stride_N
,
tox_stride_C
,
tox_stride_H
,
tox_stride_W
);
int
tix_stride_N
,
tix_stride_C
,
tix_stride_H
,
tix_stride_W
;
tensor_strides
(
top_inp_tx
,
explicit_nhwc
,
tix_stride_N
,
tix_stride_C
,
tix_stride_H
,
tix_stride_W
);
int
tih_stride_N
,
tih_stride_C
,
tih_stride_H
,
tih_stride_W
;
tensor_strides
(
top_inp_halo
,
explicit_nhwc
,
tih_stride_N
,
tih_stride_C
,
tih_stride_H
,
tih_stride_W
);
int
boh_stride_N
,
boh_stride_C
,
boh_stride_H
,
boh_stride_W
;
tensor_strides
(
btm_out_halo
,
explicit_nhwc
,
boh_stride_N
,
boh_stride_C
,
boh_stride_H
,
boh_stride_W
);
int
box_stride_N
,
box_stride_C
,
box_stride_H
,
box_stride_W
;
tensor_strides
(
btm_out_tx
,
explicit_nhwc
,
box_stride_N
,
box_stride_C
,
box_stride_H
,
box_stride_W
);
int
bix_stride_N
,
bix_stride_C
,
bix_stride_H
,
bix_stride_W
;
tensor_strides
(
btm_inp_tx
,
explicit_nhwc
,
bix_stride_N
,
bix_stride_C
,
bix_stride_H
,
bix_stride_W
);
int
bih_stride_N
,
bih_stride_C
,
bih_stride_H
,
bih_stride_W
;
tensor_strides
(
btm_inp_halo
,
explicit_nhwc
,
bih_stride_N
,
bih_stride_C
,
bih_stride_H
,
bih_stride_W
);
// determine if nhwc
auto
is_nhwc
=
(
toh_stride_C
==
1
)
?
true
:
false
;
if
(
diagnostics
)
printf
(
"is_nhwc = %s
\n
"
,
is_nhwc
?
"true"
:
"false"
);
// figure out launch parameters
int
device
;
cudaGetDevice
(
&
device
);
cudaDeviceProp
prop
;
cudaGetDeviceProperties
(
&
prop
,
device
);
assert
(
numSM
>
0
&&
numSM
<=
prop
.
multiProcessorCount
);
auto
current_stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
int
numThreads
=
128
;
dim3
block
(
numThreads
,
1
,
1
);
AT_DISPATCH_ALL_TYPES_AND
(
at
::
ScalarType
::
Half
,
top_out_halo
.
scalar_type
(),
"push_pull_halos_1d_kernel"
,
[
&
]{
if
(
diagnostics
)
printf
(
"size(scalar_t) = %ld
\n
"
,
sizeof
(
scalar_t
));
scalar_t
*
toh_p
=
top_out_halo
.
data_ptr
<
scalar_t
>
();
scalar_t
*
tox_p
=
top_out_tx
.
data_ptr
<
scalar_t
>
();
scalar_t
*
tix_p
=
top_inp_tx
.
data_ptr
<
scalar_t
>
();
scalar_t
*
tih_p
=
top_inp_halo
.
data_ptr
<
scalar_t
>
();
scalar_t
*
boh_p
=
btm_out_halo
.
data_ptr
<
scalar_t
>
();
scalar_t
*
box_p
=
btm_out_tx
.
data_ptr
<
scalar_t
>
();
scalar_t
*
bix_p
=
btm_inp_tx
.
data_ptr
<
scalar_t
>
();
scalar_t
*
bih_p
=
btm_inp_halo
.
data_ptr
<
scalar_t
>
();
if
(
diagnostics
)
printf
(
"waypoint1
\n
"
);
int
*
top_signal_p
=
top_signal
.
data_ptr
<
int
>
()
+
4
;
int
*
btm_signal_p
=
btm_signal
.
data_ptr
<
int
>
();
int
*
top_wait_p
=
waits
.
data_ptr
<
int
>
();
int
*
btm_wait_p
=
waits
.
data_ptr
<
int
>
()
+
4
;
if
(
diagnostics
)
printf
(
"waypoint2
\n
"
);
// do int4 vector loads if channel count permits
int
elem_size_in_bytes
=
toh_C
*
sizeof
(
scalar_t
);
int
elem_size_in_int4
=
(
elem_size_in_bytes
/
16
);
if
(
diagnostics
)
printf
(
"elem_size_in_bytes = %d, elem_size_in_int4 = %d
\n
"
,
elem_size_in_bytes
,
elem_size_in_int4
);
if
(
is_nhwc
&&
elem_size_in_int4
*
16
==
                elem_size_in_bytes) {
            // can do int4 transfers
            int divisor = toh_C / elem_size_in_int4;
            if (diagnostics) printf("CAN DO INT4 :: divisor = %d\n", divisor);
            toh_stride_N /= divisor;  toh_stride_H /= divisor;  toh_stride_W /= divisor;
            tox_stride_N /= divisor;  tox_stride_H /= divisor;  tox_stride_W /= divisor;
            tix_stride_N /= divisor;  tix_stride_H /= divisor;  tix_stride_W /= divisor;
            tih_stride_N /= divisor;  tih_stride_H /= divisor;  tih_stride_W /= divisor;
            boh_stride_N /= divisor;  boh_stride_H /= divisor;  boh_stride_W /= divisor;
            box_stride_N /= divisor;  box_stride_H /= divisor;  box_stride_W /= divisor;
            bix_stride_N /= divisor;  bix_stride_H /= divisor;  bix_stride_W /= divisor;
            bih_stride_N /= divisor;  bih_stride_H /= divisor;  bih_stride_W /= divisor;
            NC /= divisor;
            if (diagnostics) {
                printf("divisor=%d\n", divisor);
                printf("toh_stride :: N=%d, C=%d, H=%d, W=%d\n", toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
                printf("tox_stride :: N=%d, C=%d, H=%d, W=%d\n", tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
                printf("tix_stride :: N=%d, C=%d, H=%d, W=%d\n", tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
                printf("tih_stride :: N=%d, C=%d, H=%d, W=%d\n", tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
                printf("boh_stride :: N=%d, C=%d, H=%d, W=%d\n", boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
                printf("box_stride :: N=%d, C=%d, H=%d, W=%d\n", box_stride_N, box_stride_C, box_stride_H, box_stride_W);
                printf("bix_stride :: N=%d, C=%d, H=%d, W=%d\n", bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
                printf("bih_stride :: N=%d, C=%d, H=%d, W=%d\n", bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
                printf("NC=%d, NH=%d, NW=%d\n", NC, NH, NW);
            }
            void* kernelArgs[] = {
                (int4**)&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
                (int4**)&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
                (int4**)&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
                (int4**)&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
                (int4**)&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
                (int4**)&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
                (int4**)&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
                (int4**)&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
                &NC, &NH, &NW,
                &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
            };
            if (top_zero) {
                int numBlocksPerSm;
                cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,true,false>, numThreads, 0);
                dim3 grid(numSM*numBlocksPerSm, 1, 1);
#ifdef __HIP_PLATFORM_HCC__
                hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#else
                cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#endif
            } else if (btm_zero) {
                int numBlocksPerSm;
                cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,false,true>, numThreads, 0);
                dim3 grid(numSM*numBlocksPerSm, 1, 1);
#ifdef __HIP_PLATFORM_HCC__
                hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#else
                cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#endif
            } else {
                int numBlocksPerSm;
                cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,false,false>, numThreads, 0);
                dim3 grid(numSM*numBlocksPerSm, 1, 1);
#ifdef __HIP_PLATFORM_HCC__
                hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#else
                cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#endif
            }
        } else {
            // cannot do int4 transfers
            if (diagnostics) printf("CAN NOT DO INT4\n");
            void* kernelArgs[] = {
                &toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
                &tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
                &tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
                &tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
                &boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
                &box_p, &box_stride_C, &box_stride_H, &box_stride_W,
                &bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
                &bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
                &NC, &NH, &NW,
                &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
            };
            int numBlocksPerSm;
            if (is_nhwc) {
                if (top_zero) {
                    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,true,false>, numThreads, 0);
                    dim3 grid(numSM*numBlocksPerSm, 1, 1);
#ifdef __HIP_PLATFORM_HCC__
                    hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#else
                    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#endif
                } else if (btm_zero) {
                    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,false,true>, numThreads, 0);
                    dim3 grid(numSM*numBlocksPerSm, 1, 1);
#ifdef __HIP_PLATFORM_HCC__
                    hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#else
                    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#endif
                } else {
                    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,false,false>, numThreads, 0);
                    dim3 grid(numSM*numBlocksPerSm, 1, 1);
#ifdef __HIP_PLATFORM_HCC__
                    hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#else
                    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#endif
                }
            } else {
                if (top_zero) {
                    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,true,false>, numThreads, 0);
                    dim3 grid(numSM*numBlocksPerSm, 1, 1);
#ifdef __HIP_PLATFORM_HCC__
                    hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,true,false>, grid, block, kernelArgs, 0, current_stream);
#else
                    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,true,false>, grid, block, kernelArgs, 0, current_stream);
#endif
                } else if (btm_zero) {
                    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,false,true>, numThreads, 0);
                    dim3 grid(numSM*numBlocksPerSm, 1, 1);
#ifdef __HIP_PLATFORM_HCC__
                    hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,true>, grid, block, kernelArgs, 0, current_stream);
#else
                    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,true>, grid, block, kernelArgs, 0, current_stream);
#endif
                } else {
                    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,false,false>, numThreads, 0);
                    dim3 grid(numSM*numBlocksPerSm, 1, 1);
#ifdef __HIP_PLATFORM_HCC__
                    hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,false>, grid, block, kernelArgs, 0, current_stream);
#else
                    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,false>, grid, block, kernelArgs, 0, current_stream);
#endif
                }
            }
        }
    });
}

}
}
}
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh  deleted 100644 → 0
/**
 * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include <torch/extension.h>

#ifndef _peer_memory_h_
#define _peer_memory_h_

namespace apex { namespace contrib { namespace peer_memory {

    int64_t allocate_raw(int64_t size);
    void free_raw(int64_t raw);
    void zero(int64_t raw, int64_t size);
    at::Tensor get_raw_ipc_address(int64_t raw);
    std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);
    at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last);
    at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last);
    at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last);
    void push_pull_halos_1d(
            bool diagnostics,
            bool explicit_nhwc,
            int numSM,                   // number of SMs to use
            bool top_zero,               // true if top halo should be zeroed
            at::Tensor top_out_halo,     // top output halo in sender device memory
            at::Tensor top_out_tx,       // top output transfer buffer in sender peer pool memory
            at::Tensor top_inp_tx,       // top input transfer buffer in top neighbor peer pool memory
            at::Tensor top_inp_halo,     // top input halo in receiver device memory
            bool btm_zero,               // true if btm halo should be zeroed
            at::Tensor btm_out_halo,     // btm output halo in sender device memory
            at::Tensor btm_out_tx,       // btm output transfer buffer in sender peer pool memory
            at::Tensor btm_inp_tx,       // btm input transfer buffer in btm neighbor peer pool memory
            at::Tensor btm_inp_halo,     // btm input halo in receiver device memory
            at::Tensor top_signal,       // top input signal in receiver device memory
            at::Tensor btm_signal,       // btm input signal in receiver device memory
            at::Tensor waits             // top and btm signals for this rank
            );
} } }
#endif
apex/contrib/csrc/transducer/transducer_joint.cpp  deleted 100755 → 0
#include <torch/extension.h>
#include <ATen/Functions.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> transducer_joint_cuda_forward(
    torch::Tensor f, torch::Tensor g, torch::Tensor fLen, torch::Tensor gLen,
    torch::Tensor batchOffset, int64_t packedBatch, int opt, bool packOutput,
    bool relu, bool dropout, float dropoutProb, int tileSize);

std::vector<torch::Tensor> transducer_joint_cuda_backward(
    std::vector<torch::Tensor> in, torch::Tensor fLen, torch::Tensor gLen,
    torch::Tensor batchOffset, int maxFLen, int maxGLen, bool packOutput, float scale);

std::vector<torch::Tensor> transducer_joint_forward(
    torch::Tensor f, torch::Tensor g, torch::Tensor fLen, torch::Tensor gLen,
    torch::Tensor batchOffset, int64_t packedBatch, int opt, bool packOutput,
    bool relu, bool dropout, float dropoutProb, int tileSize) {
    CHECK_INPUT(f);
    CHECK_INPUT(g);
    CHECK_INPUT(fLen);
    CHECK_INPUT(gLen);
    if (packOutput)
        CHECK_INPUT(batchOffset);
    return transducer_joint_cuda_forward(f, g, fLen, gLen, batchOffset, packedBatch, opt,
                                         packOutput, relu, dropout, dropoutProb, tileSize);
}

std::vector<torch::Tensor> transducer_joint_backward(
    std::vector<torch::Tensor> in, torch::Tensor fLen, torch::Tensor gLen,
    torch::Tensor batchOffset, int maxFLen, int maxGLen, bool packOutput, float scale) {
    for (auto t : in) {
        CHECK_INPUT(t);
    }
    CHECK_INPUT(fLen);
    CHECK_INPUT(gLen);
    if (packOutput)
        CHECK_INPUT(batchOffset);
    return transducer_joint_cuda_backward(in, fLen, gLen, batchOffset, maxFLen, maxGLen,
                                          packOutput, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)");
    m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)");
}
\ No newline at end of file
apex/contrib/csrc/transducer/transducer_joint_kernel.cu  deleted 100755 → 0
#include <cuda.h>
#include <curand_kernel.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/AccumulateType.h>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <c10/macros/Macros.h>
#include "philox.cuh"

#ifdef __HIP_PLATFORM_HCC__
#define SHFL_DOWN(val, laneMask, width) __shfl_down(val, laneMask, width)
#else
#define SHFL_DOWN(val, laneMask, width) __shfl_down_sync(0xffffffff, val, laneMask, width)
#endif

// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.
// width should be a power of 2 and should be less than warpSize.
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width = C10_WARP_SIZE) {
    for (unsigned offset = width / 2; offset > 0; offset /= 2) {
        x += SHFL_DOWN(x, offset, width);
    }
    return x;
}

inline int largestPowerOfTwo(int x) {
    int y = 1;
    while (y <= x)
        y <<= 1;
    return y >> 1;
}

/*
Figure out vectorization type for masks.
Similar to how PyTorch figures out acc_t here:
aten/src/ATen/AccumulateType.h
*/
template <int V> struct MaskVecType {};
template <> struct MaskVecType<1> { using type = uint8_t;  };
template <> struct MaskVecType<2> { using type = uint16_t; };
template <> struct MaskVecType<4> { using type = uint32_t; };

template <int V>
using mvec_type = typename MaskVecType<V>::type;
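As a quick aside for readers of this diff, the mapping above can be sanity-checked at compile time. The static_asserts below are not part of the deleted file; they only restate what the specializations already say (assuming <type_traits> and <cstdint> are available in the translation unit):

#include <type_traits>
// Illustrative only: mvec_type<V> is the unsigned integer that moves V uint8_t mask bytes at once.
static_assert(std::is_same<mvec_type<1>, uint8_t>::value,  "V=1 -> uint8_t");
static_assert(std::is_same<mvec_type<2>, uint16_t>::value, "V=2 -> uint16_t");
static_assert(std::is_same<mvec_type<4>, uint32_t>::value, "V=4 -> uint32_t");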
// Helper class to calculate pointer offset that can be shared by different flavors of kernels.
// For fwd, batch offset and stride are different for packing and non-packing mode.
struct OffsetCalFwd {
    __device__ __forceinline__ OffsetCalFwd(
        int64_t batch, const int64_t *batchOffset, int64_t maxFLen, int64_t maxGLen,
        int64_t gLen, int64_t hiddenSize, bool packOutput) :
        batch(batch), batchOffset(batchOffset), maxFLen(maxFLen), maxGLen(maxGLen),
        gLen(gLen), hiddenSize(hiddenSize), packOutput(packOutput)
        {}

    int64_t batch;
    const int64_t *batchOffset;
    int64_t maxFLen;
    int64_t maxGLen;
    int64_t gLen;
    int64_t hiddenSize;
    bool packOutput;

    __device__ __forceinline__ int64_t getBatchOffset() {
        return packOutput ? ((batch == 0) ? 0 : batchOffset[batch-1]) * hiddenSize
                          : batch * maxFLen * maxGLen * hiddenSize;
    }

    __device__ __forceinline__ int64_t getStrideF() {
        return packOutput ? gLen * hiddenSize : maxGLen * hiddenSize;
    }
};
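To make the packing arithmetic concrete, here is a small host-side sketch that mirrors getBatchOffset()/getStrideF() for a made-up two-sequence batch. The lengths, hiddenSize, and batchOffset values below are invented for illustration; in the real extension the prefix sums arrive from the Python side:

#include <cstdint>
#include <cstdio>

int main() {
    // Toy shapes: batch 0 occupies 2*3 = 6 (t,u) cells, batch 1 occupies 1*2 = 2 cells.
    const int64_t hiddenSize = 4;
    const int64_t batchOffset[] = {6, 8};  // cumulative cells; batch 0 starts at 0, as in the kernel
    for (int64_t batch = 0; batch < 2; ++batch) {
        int64_t gLen = (batch == 0) ? 3 : 2;
        // Same formulas as OffsetCalFwd::getBatchOffset / getStrideF with packOutput == true.
        int64_t off     = ((batch == 0) ? 0 : batchOffset[batch-1]) * hiddenSize;
        int64_t strideF = gLen * hiddenSize;
        printf("batch %lld: element offset %lld, stride along t %lld\n",
               (long long)batch, (long long)off, (long long)strideF);
    }
    return 0;
}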
// Helper class to calculate pointer offset that can be shared by different flavors of kernels
// For bwd, batch offset and stride are different for packing and non-packing mode.
// The reduction is done for two input tensors. Therefore, generating two sets of offsets
// according to bwdFasterDim can lead to a unified implementation in the actual kernel.
struct OffsetCalBwd {
    __device__ __forceinline__ OffsetCalBwd(
        int64_t batch, const int64_t *batchOffset, const int *fLen, const int *gLen,
        int64_t maxFLen, int64_t maxGLen, int64_t hiddenSize, bool packOutput, bool bwdFasterDim) :
        batch(batch), batchOffset(batchOffset), maxFLen(maxFLen), maxGLen(maxGLen),
        fLen(fLen), gLen(gLen), hiddenSize(hiddenSize), packOutput(packOutput),
        bwdFasterDim(bwdFasterDim)
        {}

    int64_t batch;
    const int64_t *batchOffset;
    const int *fLen;
    const int *gLen;
    int64_t maxFLen;
    int64_t maxGLen;
    int64_t hiddenSize;
    bool packOutput;
    bool bwdFasterDim;  // whether doing bwd on the faster moving dimension

    __device__ __forceinline__ int64_t getBatchOffset() {
        return packOutput ? ((batch == 0) ? 0 : batchOffset[batch-1]) * hiddenSize
                          : batch * maxFLen * maxGLen * hiddenSize;
    }

    __device__ __forceinline__ int64_t getMaxXLen() {
        return bwdFasterDim ? maxGLen : maxFLen;
    }

    __device__ __forceinline__ auto getMyXLen() -> decltype(gLen[batch]) {
        return bwdFasterDim ? gLen[batch] : fLen[batch];
    }

    __device__ __forceinline__ auto getMyYLen() -> decltype(gLen[batch]) {
        return bwdFasterDim ? fLen[batch] : gLen[batch];
    }

    __device__ __forceinline__ int64_t getStrideX() {
        return bwdFasterDim ? hiddenSize
                            : ((packOutput ? gLen[batch] : maxGLen) * hiddenSize);
    }

    __device__ __forceinline__ int64_t getStrideY() {
        return bwdFasterDim ? ((packOutput ? gLen[batch] : maxGLen) * hiddenSize)
                            : hiddenSize;
    }
};
// Vanilla transducer joint forward kernel
// Detail of this joint function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.

// f is a tensor of shape [batch, T, H]
// g is a tensor of shape [batch, U, H]
// the transducer joint does
// sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
// The resultant tensor is of shape [batch, T, U, H]
// Each thread block is working on one "batch" of data in the output tensor, [batch, t, u, :]

// This joint function can optionally pack the output where the output tensor with a shape of
// [B, T, U, H] is packed into [B_packed, H].
// Don't-care region (t > fLen) or (u > gLen) is removed.
// To enable packing, the starting offset for each batch needs to be specified with batchOffset.
template <typename scalar_t, class OffsetCal>
__global__ void transducer_joint_forward(
    const scalar_t *f, const scalar_t *g, const int *fLen, const int *gLen,
    const int64_t *batchOffset, int64_t maxFLen, int64_t maxGLen, int64_t hiddenSize,
    bool packOutput, scalar_t *sum) {

    const int batch = blockIdx.z;
    const int t = blockIdx.y;
    const int u = blockIdx.x;
    const auto myFLen = fLen[batch];
    const auto myGLen = gLen[batch];

    OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);
    const auto myBatchOffset = offsetCal.getBatchOffset();
    const auto strideF = offsetCal.getStrideF();

    scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize;
    scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize;
    scalar_t *mySum = sum + myBatchOffset + t*strideF + u*hiddenSize;

    if (t < myFLen and u < myGLen) {
        #pragma unroll
        for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x) {
            if (h < hiddenSize) {
                mySum[h] = myF[h] + myG[h];
            }
        }
    }
    else if (packOutput == false and t < maxFLen and u < maxGLen) {
        // Need to write finite data to don't-care region because we instantiate the result tensor
        // with torch::empty for performance reasons. Even though it is don't-care region, the
        // contents need to be finite, otherwise could lead to NaN in WGRAD.
        // In packing mode, this write is no longer necessary as we remove the don't-care region
        // from the output.
        // Picking -1 (over 0) here for ease of testing.
        #pragma unroll
        for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x) {
            if (h < hiddenSize) {
                mySum[h] = -1;
            }
        }
    }
}
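Since the kernel above is just a broadcast add, a plain host-side reference is easy to state and is handy when eyeballing the indexing. Everything below (shapes, values) is invented for illustration and is not part of the deleted source:

#include <vector>
#include <cstdio>

int main() {
    const int B = 1, T = 2, U = 3, H = 2;
    std::vector<float> f(B*T*H), g(B*U*H), sum(B*T*U*H);
    for (size_t i = 0; i < f.size(); ++i) f[i] = 0.1f * i;
    for (size_t i = 0; i < g.size(); ++i) g[i] = 1.0f * i;
    // sum[b,t,u,h] = f[b,t,h] + g[b,u,h], i.e. f.unsqueeze(2) + g.unsqueeze(1)
    for (int b = 0; b < B; ++b)
        for (int t = 0; t < T; ++t)
            for (int u = 0; u < U; ++u)
                for (int h = 0; h < H; ++h)
                    sum[((b*T + t)*U + u)*H + h] = f[(b*T + t)*H + h] + g[(b*U + u)*H + h];
    printf("sum[0,1,2,1] = %f\n", sum[((0*T + 1)*U + 2)*H + 1]);
    return 0;
}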
/*
Tiled version of the joint forward kernel
Detail of this joint function can be found in:
[1] Sequence Transduction with Recurrent Neural Networks.

f is a tensor of shape [batch, T, H]
g is a tensor of shape [batch, U, H]
the transducer joint does
sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
The resultant tensor is of shape [batch, T, U, H]
Each thread is working on a tile of the shape of tileF x tileG in the result tensor.
The input for the tile is first loaded in the register and is reused tileG and tileF times.

This joint function can optionally pack the output where the output tensor with a shape of
[B, T, U, H] is packed into [B_packed, H].
Don't-care region (t > fLen) or (u > gLen) is removed.
To enable packing, the starting offset for each batch needs to be specified with batchOffset.

Optionally this joint function performs ReLU and/or dropout on the joint output, which is
controlled by arguments relu and dropout, respectively. philoxArgs is the argument used for
generating pseudorandom numbers. When at least one of ReLU and dropout is activated, the joint
function is a masked operation, which is controlled by the template argument masked. In this case,
masks are saved for backward.
*/
template <typename scalar_t, int tileF, int tileG, int U, class OffsetCal, bool masked>
__global__ void transducer_joint_tiled_forward(
    const scalar_t *f, const scalar_t *g, const int *fLen, const int *gLen,
    const int64_t *batchOffset, int64_t maxFLen, int64_t maxGLen, int64_t hiddenSize,
    int64_t hiddenPerBlock, bool packOutput, bool relu, bool dropout, float p,
    at::PhiloxCudaState philoxArgs, scalar_t *sum, uint8_t *mask) {

    static_assert(U == 4, "U has to be 4, as random numbers are generated in batch of 4");

    const int batch = blockIdx.z;
    const int t = blockIdx.y * tileF;
    const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;
    const int u = blockIdx.x / hiddenBlock * tileG;
    const int hOffset = (blockIdx.x % hiddenBlock) * hiddenPerBlock;
    const int h = threadIdx.x;
    const auto myFLen = fLen[batch];
    const auto myGLen = gLen[batch];

    OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);
    const auto myBatchOffset = offsetCal.getBatchOffset();
    const auto strideF = offsetCal.getStrideF();

    scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize + hOffset;
    scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize + hOffset;
    scalar_t *mySum = sum + myBatchOffset + t*strideF + u*hiddenSize + hOffset;
    uint8_t *myMask = mask + myBatchOffset + t*strideF + u*hiddenSize + hOffset;

    // The following code is only needed for dropout. We try to bypass them as much as possible.
    auto seeds = masked ? at::cuda::philox::unpack(philoxArgs)
                        : std::make_tuple(static_cast<uint64_t>(0), static_cast<uint64_t>(0));
    uint64_t tid = masked ? (static_cast<uint64_t>(blockIdx.z)*gridDim.y*gridDim.x
                             + blockIdx.y*gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x
                          : 0;
    Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));
    scalar_t scale = masked ? ((p == 0) ? 0 : 1 / p) : 0;
    bool dropoutMask[U];

    if (t < myFLen and u < myGLen and hOffset + h < hiddenSize) {
        // register buffers for tiled input reuse
        scalar_t fBuffer[tileF], gBuffer[tileG];
        for (int i = 0; i < tileF; ++i) {
            if (t + i < myFLen)
                fBuffer[i] = myF[i*hiddenSize + h];
        }
        for (int j = 0; j < tileG; ++j) {
            if (u + j < myGLen)
                gBuffer[j] = myG[j*hiddenSize + h];
        }

        #pragma unroll
        for (int i = 0; i < tileF; ++i) {
            if (t + i < myFLen) {
                #pragma unroll
                for (int j = 0; j < tileG; ++j) {
                    int idx = i*tileG + j;
                    if (masked and dropout and idx % U == 0) {
                        // For performance, generate 4 random numbers in one shot
                        // auto rand4 = curand_uniform4(&state);
                        auto rand4 = uniform4(ph());
                        dropoutMask[0] = rand4.x < p;
                        dropoutMask[1] = rand4.y < p;
                        dropoutMask[2] = rand4.z < p;
                        dropoutMask[3] = rand4.w < p;
                    }
                    if (u + j < myGLen) {
                        scalar_t out = fBuffer[i] + gBuffer[j];
                        if (masked) {
                            // Apply ReLU here when relu is True
                            bool localMask = relu ? (out > 0) : 1;
                            localMask = dropout ? localMask & dropoutMask[idx % U] : localMask;
                            out = dropout ? out * localMask * scale : out * localMask;
                            myMask[i*strideF + j*hiddenSize + h] = static_cast<uint8_t>(localMask);
                        }
                        mySum[i*strideF + j*hiddenSize + h] = out;
                    }
                    else if (packOutput == false and u + j < maxGLen)
                        mySum[i*strideF + j*hiddenSize + h] = -1;
                }
            }
            else if (packOutput == false and t + i < maxFLen) {
                // Again need to write finite data to don't-care region
                #pragma unroll
                for (int j = 0; j < tileG; ++j) {
                    if (u + j < maxGLen)
                        mySum[i*strideF + j*hiddenSize + h] = -1;
                }
            }
        }
    }
    else if (packOutput == false and t < maxFLen and u < maxGLen and hOffset + h < hiddenSize) {
        // Only need to ensure the finity in normal mode
        #pragma unroll
        for (int i = 0; i < tileF; ++i) {
            if (t + i < maxFLen) {
                #pragma unroll
                for (int j = 0; j < tileG; ++j) {
                    if (u + j < maxGLen)
                        mySum[i*strideF + j*hiddenSize + h] = -1;
                }
            }
        }
    }
}
/*
Bwd operation (reduction) on one input tensor. Since the operation performed for the two input
tensors is exactly the same, only one kernel is needed, and the different indexing offsets
and strides are handled by OffsetCalBwd.

When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
non-packed form.

When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, class OffsetCal, bool masked>
__device__ void transducer_joint_single_backward(
    const scalar_t *grad, const uint8_t *mask, const int *fLen, const int *gLen,
    const int64_t *batchOffset, int64_t maxFLen, int64_t maxGLen, int64_t hiddenSize,
    bool packOutput,
    bool bwdFasterDim,  // whether bwd on the faster moving dimension (u)
    float scale, scalar_t *inGrad, int yBlockOffset = 0) {

    const int batch = blockIdx.z;
    // For the second input tensor, this offset need to be subtracted because the first yBlockOffset
    // sets of thread blocks are for the first input tensor.
    const int x = blockIdx.y - yBlockOffset;
    const int hOffset = blockIdx.x * C10_WARP_SIZE;
    const int wid = threadIdx.y;
    const int lid = threadIdx.x;
    const int numWarp = blockDim.y;
    extern __shared__ char smem8[];
    auto smem = reinterpret_cast<acc_t*>(smem8);

    OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize,
                        packOutput, bwdFasterDim);
    const auto maxXLen = offsetCal.getMaxXLen();
    const auto myXLen = offsetCal.getMyXLen();
    const auto myYLen = offsetCal.getMyYLen();
    scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;

    if (x < myXLen) {
        const auto myBatchOffset = offsetCal.getBatchOffset();
        const auto strideX = offsetCal.getStrideX();
        const auto strideY = offsetCal.getStrideY();
        const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;
        const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset : nullptr;

        // Each warp reduces numYPerWarp "y" first
        acc_t warpSum = 0;
        auto numYPerWarp = (myYLen + numWarp - 1) / numWarp;
        #pragma unroll
        for (int warpY = 0; warpY < numYPerWarp; ++warpY) {
            auto y = wid*numYPerWarp + warpY;
            if (y < myYLen and (hOffset + lid) < hiddenSize)
                if (masked)
                    warpSum += static_cast<acc_t>(myGrad[y*strideY + lid]) * myMask[y*strideY + lid] * scale;
                else
                    warpSum += myGrad[y*strideY + lid];
        }

        // transpose partial sum in SMEM and reduce further using warpReduce
        smem[lid*numWarp + wid] = warpSum;
        __syncthreads();
        auto sum = smem[wid*C10_WARP_SIZE + lid];
        sum = warpReduce(sum, numWarp);

        // a a b b c c d d
        // a a b b c c d d
        // a a b b c c d d
        // a a b b c c d d
        // example of 4 warps (a, b, c, d) with 8 threads per warp
        // Each warp need 8 / 4 = 2 threads to write the results.
        if (hOffset + wid*C10_WARP_SIZE/numWarp + lid/numWarp < hiddenSize) {
            if (lid % numWarp == 0) {
                myInGrad[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = sum;
            }
        }
    }
    else if (wid == 0 and hOffset + lid < hiddenSize) {
        // Need to ensure the grad is zero for don't care region
        myInGrad[lid] = 0;
    }
}
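The reduction above pairs with the forward broadcast: the gradient w.r.t. f sums the incoming gradient over u, and the gradient w.r.t. g sums it over t. A minimal host-side reference, with invented shapes and a gradient of all ones (again, a sketch for orientation, not part of the deleted file):

#include <vector>
#include <cstdio>

int main() {
    const int B = 1, T = 2, U = 3, H = 2;
    std::vector<float> grad(B*T*U*H, 1.0f), fGrad(B*T*H, 0.0f), gGrad(B*U*H, 0.0f);
    for (int b = 0; b < B; ++b)
        for (int t = 0; t < T; ++t)
            for (int u = 0; u < U; ++u)
                for (int h = 0; h < H; ++h) {
                    float gval = grad[((b*T + t)*U + u)*H + h];
                    fGrad[(b*T + t)*H + h] += gval;  // reduce over u
                    gGrad[(b*U + u)*H + h] += gval;  // reduce over t
                }
    printf("fGrad[0,0,0] = %f, gGrad[0,0,0] = %f\n", fGrad[0], gGrad[0]);  // prints 3 and 2
    return 0;
}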
/*
Actual bwd (reduction) kernel that gets launched.
Calls transducer_joint_single_backward twice on the two input tensors.
The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op
uses the rest.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, class OffsetCal, bool masked>
__global__ void transducer_joint_combined_backward(
    const scalar_t *grad, const uint8_t *mask, const int *fLen, const int *gLen,
    const int64_t *batchOffset, int64_t maxFLen, int64_t maxGLen, int64_t hiddenSize,
    bool packOutput, float scale, scalar_t *fGrad, scalar_t *gGrad) {
    if (blockIdx.y < maxFLen) {
        transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(
            grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize,
            packOutput, false, scale, fGrad);
    }
    else {
        transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(
            grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize,
            packOutput, true, scale, gGrad, maxFLen);
    }
}
/*
Vectorized version of transducer_joint_single_backward
Doing the exact same operation as transducer_joint_single_backward except the load and store are
vectorized.
When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
non-packed form.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>
__device__ void transducer_joint_single_vec_backward(
    const scalar_t *grad, const uint8_t *mask, const int *fLen, const int *gLen,
    const int64_t *batchOffset, int64_t maxFLen, int64_t maxGLen, int64_t hiddenSize,
    bool packOutput, bool bwdFasterDim, float scale, scalar_t *inGrad, int yBlockOffset = 0) {

    const int batch = blockIdx.z;
    const int x = blockIdx.y - yBlockOffset;
    const int hOffset = blockIdx.x * C10_WARP_SIZE * V;
    const int wid = threadIdx.y;
    const int lid = threadIdx.x;
    const int numWarp = blockDim.y;
    // Figure out the vectorization type for mask
    using mvec_t = mvec_type<V>;

    OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize,
                        packOutput, bwdFasterDim);
    const auto maxXLen = offsetCal.getMaxXLen();
    const auto myXLen = offsetCal.getMyXLen();
    const auto myYLen = offsetCal.getMyYLen();
    scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;
    extern __shared__ char smem8[];
    auto smem = reinterpret_cast<acc_t*>(smem8);

    acc_t warpSum[V];
    scalar_t inBuffer[V];
    uint8_t maskBuffer[V];
    scalar_t outBuffer[V];
    auto myInGradVec = reinterpret_cast<vec_t*>(myInGrad);
    auto outBufferVec = reinterpret_cast<vec_t*>(outBuffer);

    if (x < myXLen) {
        const auto myBatchOffset = offsetCal.getBatchOffset();
        const auto strideX = offsetCal.getStrideX();
        const auto strideY = offsetCal.getStrideY();
        const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;
        const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset : nullptr;

        for (int i = 0; i < V; ++i)
            warpSum[i] = 0;

        // Each warp reduces numYPerWarp "y" first
        auto numYPerWarp = (myYLen + numWarp - 1) / numWarp;
        for (int warpY = 0; warpY < numYPerWarp; ++warpY) {
            auto y = wid*numYPerWarp + warpY;
            auto myGradVec = reinterpret_cast<vec_t const *>(myGrad + y*strideY);
            auto myMaskVec = masked ? reinterpret_cast<mvec_t const *>(myMask + y*strideY) : nullptr;
            auto inBufferVec = reinterpret_cast<vec_t*>(inBuffer);
            auto maskBufferVec = reinterpret_cast<mvec_t*>(maskBuffer);
            if (hOffset + lid*V < hiddenSize and y < myYLen) {
                *inBufferVec = myGradVec[lid];  // vectorized load
                if (masked) {
                    *maskBufferVec = myMaskVec[lid];
                    #pragma unroll
                    for (int i = 0; i < V; ++i)
                        warpSum[i] += static_cast<acc_t>(inBuffer[i]) * maskBuffer[i] * scale;
                }
                else {
                    #pragma unroll
                    for (int i = 0; i < V; ++i)
                        warpSum[i] += inBuffer[i];
                }
            }
        }

        // transpose partial sum in SMEM and reduce further using warpReduce
        for (int i = 0; i < V; ++i) {
            smem[lid*numWarp + wid] = warpSum[i];
            __syncthreads();
            auto sum = smem[wid*C10_WARP_SIZE + lid];
            if (hOffset + (wid*C10_WARP_SIZE/numWarp)*V < hiddenSize) {
                sum = warpReduce(sum, numWarp);
                if (lid % numWarp == 0) {
                    outBuffer[i] = sum;
                }
            }
            __syncthreads();
        }

        // a a b b c c d d
        // a a b b c c d d
        // a a b b c c d d
        // a a b b c c d d
        // example of 4 warps (a, b, c, d) with 8 threads per warp
        // Each warp need 8 / 4 = 2 threads to write the results.
        if (lid % numWarp == 0 and hOffset + (wid*C10_WARP_SIZE/numWarp + lid/numWarp)*V < hiddenSize)
            myInGradVec[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = *outBufferVec;
    }
    else if (wid == 0 and hOffset + lid*V < hiddenSize) {
        // Need to ensure the grad is zero for don't care region
        myInGradVec[lid] = 0;
    }
}
/*
Vectorized version of transducer_joint_combined_backward
Calls transducer_joint_single_vec_backward twice on the two input tensors.
The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op
uses the rest.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>
__global__ void transducer_joint_combined_vec_backward(
    const scalar_t *grad, const uint8_t *mask, const int *fLen, const int *gLen,
    const int64_t *batchOffset, int64_t maxFLen, int64_t maxGLen, int64_t hiddenSize,
    bool packOutput, float scale, scalar_t *fGrad, scalar_t *gGrad) {
    if (blockIdx.y < maxFLen) {
        transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(
            grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize,
            packOutput, false, scale, fGrad);
    }
    else {
        transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(
            grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize,
            packOutput, true, scale, gGrad, maxFLen);
    }
}
std::vector<torch::Tensor> transducer_joint_cuda_forward(
    torch::Tensor f, torch::Tensor g, torch::Tensor fLen, torch::Tensor gLen,
    torch::Tensor batchOffset, int64_t packedBatch, int opt, bool packOutput,
    bool relu, bool dropout, float dropoutProb, int tileSize) {

    auto tensorOpt = f.options();
    auto dtype = f.scalar_type();
    const auto batchSize = f.size(0);
    const auto maxFLen = f.size(1);
    const auto maxGLen = g.size(1);
    const auto hiddenSize = f.size(2);
    bool masked = dropout or relu;

    int64_t *batchOffsetPtr = nullptr;
    torch::Tensor sum, mask;
    auto maskOpt = tensorOpt.dtype(torch::kUInt8);
    if (!packOutput) {
        sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt);
        batchOffsetPtr = nullptr;
        if (masked)
            mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt);
    }
    else {
        sum = torch::empty({packedBatch, hiddenSize}, tensorOpt);
        batchOffsetPtr = batchOffset.data_ptr<int64_t>();
        if (masked)
            mask = torch::empty({packedBatch, hiddenSize}, maskOpt);
    }
    uint8_t *maskPtr = masked ? mask.data_ptr<uint8_t>() : nullptr;

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    TORCH_CHECK(opt == 0 or opt == 1, "Got an invalid optimization level ", opt);
    // Simple heuristics
    const int numThread = std::min(128, (static_cast<int>(hiddenSize) + C10_WARP_SIZE - 1)
                                            / C10_WARP_SIZE * C10_WARP_SIZE);

    if (opt == 0) {
        // vanilla kernel
        const int threads = numThread;
        const dim3 blocks(maxGLen, maxFLen, batchSize);

        AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] {
            transducer_joint_forward<scalar_t, OffsetCalFwd>
                <<<blocks, threads, 0, stream>>>(
                    f.data_ptr<scalar_t>(), g.data_ptr<scalar_t>(),
                    fLen.data_ptr<int>(), gLen.data_ptr<int>(),
                    batchOffsetPtr, maxFLen, maxGLen, hiddenSize, packOutput,
                    sum.data_ptr<scalar_t>());
        }));
    }
    if (opt == 1) {
        // tiled version. For simplicity, assume tileF == tileG, even though the kernel can
        // support more general cases.
        const int threads = numThread;
        const int hiddenPerBlock = numThread;
        const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;
        const dim3 blocks((maxGLen + tileSize - 1) / tileSize * hiddenBlock,
                          (maxFLen + tileSize - 1) / tileSize,
                          batchSize);
        TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4,
                    "Expected tileSize to be in [1, 2, 4], but got ", tileSize);

        at::PhiloxCudaState rng_engine_inputs;
        if (masked) {
            // set up PRG when the input is masked. rng_engine_inputs will be used as a space filler
            // for non-masked calls.
            // Therefore no need to initialize.
            c10::optional<at::Generator> gen_;
            auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
                gen_, at::cuda::detail::getDefaultCUDAGenerator());
            // counterOffset records how many cuRAND calls each thread makes. For a tiled kernel,
            // each thread processes tileF * tileG output elements.
            int64_t counterOffset = tileSize * tileSize;
            {
                std::lock_guard<std::mutex> lock(gen->mutex_);
                rng_engine_inputs = gen->philox_cuda_state(counterOffset);
            }
        }
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] {
            void (*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*,
                           int64_t, int64_t, int64_t, int64_t, bool, bool, bool, float,
                           at::PhiloxCudaState, scalar_t*, uint8_t*);
            if (masked) {
                switch (tileSize) {
                    case 2:
                        kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd, true>;
                        break;
                    case 4:
                        kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd, true>;
                        break;
                }
            }
            else {
                switch (tileSize) {
                    case 1:
                        kernel = &transducer_joint_tiled_forward<scalar_t, 1, 1, 4, OffsetCalFwd, false>;
                        break;
                    case 2:
                        kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd, false>;
                        break;
                    case 4:
                        kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd, false>;
                        break;
                }
            }
            kernel<<<blocks, threads, 0, stream>>>(
                f.data_ptr<scalar_t>(), g.data_ptr<scalar_t>(),
                fLen.data_ptr<int>(), gLen.data_ptr<int>(),
                batchOffsetPtr, maxFLen, maxGLen, hiddenSize, hiddenPerBlock,
                packOutput, relu, dropout, 1.0f - dropoutProb, rng_engine_inputs,
                sum.data_ptr<scalar_t>(), maskPtr);
        }));
    }
    C10_CUDA_CHECK(cudaGetLastError());

    if (masked)
        return {sum, mask};
    else
        return {sum};
}
std::vector<torch::Tensor> transducer_joint_cuda_backward(
    std::vector<torch::Tensor> in, torch::Tensor fLen, torch::Tensor gLen,
    torch::Tensor batchOffset, int maxFLen, int maxGLen, bool packOutput, float scale) {

    auto grad = in[0];
    bool masked = (in.size() == 2);
    uint8_t *maskPtr = masked ? in[1].data_ptr<uint8_t>() : nullptr;

    auto tensorOpt = grad.options();
    auto dtype = grad.scalar_type();
    const int batchSize = fLen.size(0);
    const int hiddenSize = grad.size(-1);

    const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
    const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE;

    torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt);
    torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt);

    int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr<int64_t>();

    // The number "y" I would like each thread to work on
    const int workPerThread = 32;
    // Since the bwd for f and g have the same thread block size, we need to use the max of the two.
    int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread - 1) / workPerThread);
    // Would like to have at least 2 warps
    numWarp = std::max(2, numWarp);
    // cap on the maximum number of warps allowed
    numWarp = std::min(maxNumWarp, numWarp);

    // Need smem for transposing the partial sum. The partial sum is in a matrix of the shape
    // numWarp x warpSize
    const int smemSize = numWarp * C10_WARP_SIZE;
    const dim3 threads(C10_WARP_SIZE, numWarp, 1);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_cuda_backward_kernel", ([&] {
        auto gradPtr = grad.data_ptr<scalar_t>();
        auto fLenPtr = fLen.data_ptr<int>();
        auto gLenPtr = gLen.data_ptr<int>();
        auto fGradPtr = fGrad.data_ptr<scalar_t>();
        auto gGradPtr = gGrad.data_ptr<scalar_t>();

        // resolve the acc_t type
        using acc_t = at::acc_type<scalar_t, true>;
        using vec_t = uint64_t;

        constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);
        constexpr int vecAlignment = std::alignment_of<vec_t>::value;

        // if all input and output tensors meet the alignment requirement
        bool memAlign = (reinterpret_cast<uint64_t>(gradPtr) % vecAlignment == 0)
                        and (reinterpret_cast<uint64_t>(fGradPtr) % vecAlignment == 0)
                        and (reinterpret_cast<uint64_t>(gGradPtr) % vecAlignment == 0);

        if (vectFactor > 1 and hiddenSize % vectFactor == 0 and memAlign) {
            // If vectorization helps and the alignment requirement is met, use the vectorized
            // kernel. For simplicity, hiddenSize needs to be a multiple vecFactor.
            const dim3 blocks((hiddenSize + C10_WARP_SIZE*vectFactor - 1) / (C10_WARP_SIZE*vectFactor),
                              maxFLen + maxGLen, batchSize);
            if (masked) {
                transducer_joint_combined_vec_backward<scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, true>
                    <<<blocks, threads, smemSize*sizeof(acc_t)>>>(
                        gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr,
                        maxFLen, maxGLen, hiddenSize, packOutput, scale, fGradPtr, gGradPtr);
            }
            else {
                transducer_joint_combined_vec_backward<scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, false>
                    <<<blocks, threads, smemSize*sizeof(acc_t)>>>(
                        gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr,
                        maxFLen, maxGLen, hiddenSize, packOutput, scale, fGradPtr, gGradPtr);
            }
        }
        else {
            const dim3 blocks((hiddenSize + C10_WARP_SIZE - 1) / C10_WARP_SIZE,
                              maxFLen + maxGLen, batchSize);
            if (masked) {
                transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, true>
                    <<<blocks, threads, smemSize*sizeof(acc_t)>>>(
                        gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr,
                        maxFLen, maxGLen, hiddenSize, packOutput, scale, fGradPtr, gGradPtr);
            }
            else {
                transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, false>
                    <<<blocks, threads, smemSize*sizeof(acc_t)>>>(
                        gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr,
                        maxFLen, maxGLen, hiddenSize, packOutput, scale, fGradPtr, gGradPtr);
            }
        }
    }));

    return {fGrad, gGrad};
}
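As a side note on the vectorization gate in the dispatch above, here is a tiny standalone sketch of the same eligibility test. The pointer value and sizes are made up for illustration; the real code checks the actual tensor data pointers:

#include <cstdint>
#include <cstdio>

int main() {
    // A uint64_t vector type moves 4 half or 2 float elements per load/store, so the pointers
    // must be 8-byte aligned and hiddenSize must be a multiple of the vector factor.
    using vec_t = uint64_t;
    const int hiddenSize = 512;
    const int sizeofScalar = 2;  // e.g. fp16
    const int vectFactor = sizeof(vec_t) / sizeofScalar;
    uintptr_t fakePtr = 0x7f0000001000;  // made-up, 8-byte aligned address
    bool memAlign = (fakePtr % alignof(vec_t) == 0);
    bool useVec = vectFactor > 1 && hiddenSize % vectFactor == 0 && memAlign;
    printf("vectFactor=%d, vectorized path: %s\n", vectFactor, useVec ? "yes" : "no");
    return 0;
}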
apex/contrib/csrc/transducer/transducer_loss.cpp  deleted 100644 → 0
#include <torch/extension.h>
#include <vector>

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> transducer_loss_cuda_forward(
    torch::Tensor x, torch::Tensor label, torch::Tensor audLen, torch::Tensor txtLen,
    torch::Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool packedInput);

torch::Tensor transducer_loss_cuda_backward(
    torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, torch::Tensor beta,
    torch::Tensor audLen, torch::Tensor txtLen, torch::Tensor label, torch::Tensor batchOffset,
    int maxFLen, int blankIdx, int opt, bool fuseSoftmaxBackward, bool packedInput);

std::vector<torch::Tensor> transducer_loss_forward(
    torch::Tensor x, torch::Tensor label, torch::Tensor fLen, torch::Tensor yLen,
    torch::Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool packedInput) {

    CHECK_INPUT(x);
    CHECK_INPUT(label);
    CHECK_INPUT(fLen);
    CHECK_INPUT(yLen);
    if (packedInput)
        CHECK_INPUT(batchOffset);
    return transducer_loss_cuda_forward(x, label, fLen, yLen, batchOffset,
                                        maxFLen, blankIdx, opt, packedInput);
}

torch::Tensor transducer_loss_backward(
    torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, torch::Tensor beta,
    torch::Tensor fLen, torch::Tensor yLen, torch::Tensor label, torch::Tensor batchOffset,
    int maxFLen, int blankIdx, int opt, bool fuseSoftmaxBackward, bool packedInput) {

    CHECK_INPUT(x);
    CHECK_INPUT(label);
    CHECK_INPUT(lossGrad);
    CHECK_INPUT(alpha);
    CHECK_INPUT(beta);
    CHECK_INPUT(fLen);
    CHECK_INPUT(yLen);
    if (packedInput)
        CHECK_INPUT(batchOffset);

    return transducer_loss_cuda_backward(x, lossGrad, alpha, beta, fLen, yLen, label,
                                         batchOffset, maxFLen, blankIdx, opt,
                                         fuseSoftmaxBackward, packedInput);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)");
    m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)");
}
apex/contrib/csrc/transducer/transducer_loss_kernel.cu  deleted 100755 → 0
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>

template <typename scalar_t>
__device__ __forceinline__ scalar_t logSumExp(scalar_t a, scalar_t b) {
    // standard log-sum-exp trick is used here to provide better numerical stability
    return (a >= b) ? a + std::log1p(exp(b - a)) : b + std::log1p(exp(a - b));
}
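A one-line restatement of the identity the helper above relies on (standard math, not something new to the file): for any $a, b$,

$$\log\!\big(e^{a}+e^{b}\big)\;=\;\max(a,b)\;+\;\operatorname{log1p}\!\big(e^{-\lvert a-b\rvert}\big),$$

so the exponential is always evaluated at a non-positive argument and cannot overflow; the branch on `a >= b` in the code selects which operand plays the role of the maximum.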
// Vanilla transducer loss function (i.e. forward-backward algorithm)
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.

// Forward (alpha) and backward (beta) path are launched together. Input is assumed to be converted
// into log scale by the preceding log_softmax layer
// Diagonal wavefront advancing usually used in dynamic programming is leveraged here.
// alpha and beta are of acc_t type, as they are essentially accumulators.

// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into
// [B_packed, H].
// Don't-care region (t > audLen) or (u > txtLen) is removed.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_forward(
    const scalar_t *x, const int *label, const int *audLen, const int *txtLen,
    const int64_t *batchOffset,
    int64_t dictSize,   // 64-bit indexing for data tensor
    int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput,
    acc_t *alpha, acc_t *beta, scalar_t *loss) {

    const int batch = blockIdx.y;
    const int tid = threadIdx.x;
    const auto myFLen = audLen[batch];
    // Note that start of the sentence is added as 1 here
    const auto myGLen = txtLen[batch] + 1;
    const auto myLabel = label + batch * (maxGLen - 1);
    const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
                                              : batch * maxFLen * maxGLen;
    const int64_t myStrideT = packedInput ? myGLen : maxGLen;
    const scalar_t *myX = x + myBatchOffset * dictSize;
    int u = tid;

    if (blockIdx.x == 0) {
        // alpha path
        acc_t *myAlpha = alpha + batch*maxFLen*maxGLen;
        if (u == 0)
            myAlpha[0] = 0;
        __syncthreads();

        for (int64_t step = 1; step < myFLen + myGLen - 1; ++step) {
            // Move along the diagonal wavefront to leverage available parallelism
            for (u = tid; u < myGLen; u += blockDim.x) {
                int64_t t = step - u;
                if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
                    // Eq(16) in [1]
                    if (u == 0) {
                        // alpha(t, u) = alpha(t-1, u) * null(t-1, u)
                        myAlpha[t*maxGLen + u] = myAlpha[(t-1)*maxGLen]
                                                 + myX[((t-1)*myStrideT) * dictSize + blankIdx];
                    }
                    else if (t == 0) {
                        // alpha(t, u-1) = alpha(t, u-1) * y(t, u-1)
                        myAlpha[u] = myAlpha[u-1] + myX[(u-1)*dictSize + myLabel[u-1]];
                    }
                    else {
                        // alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1)
                        acc_t current = myAlpha[(t-1)*maxGLen + u]
                                        + myX[((t-1)*myStrideT + u) * dictSize + blankIdx];
                        acc_t next = myAlpha[t*maxGLen + u - 1]
                                     + myX[(t*myStrideT + u - 1) * dictSize + myLabel[u-1]];
                        myAlpha[t*maxGLen + u] = logSumExp(next, current);
                    }
                }
            }
            __syncthreads();
        }
    }
    else if (blockIdx.x == 1) {
        // beta path
        acc_t *myBeta = beta + batch*maxFLen*maxGLen;
        if (u == 0) {
            myBeta[(myFLen-1)*maxGLen + myGLen - 1]
                = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];
        }
        __syncthreads();

        for (int64_t step = myFLen + myGLen - 3; step >= 0; --step) {
            for (u = tid; u < myGLen; u += blockDim.x) {
                int64_t t = step - u;
                if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
                    // Eq(18) in [1]
                    if (u == myGLen - 1) {
                        // beta(t, u) = beta(t+1, u) * null(t, u)
                        myBeta[t*maxGLen + u] = myBeta[(t+1)*maxGLen + u]
                                                + myX[(t*myStrideT + u) * dictSize + blankIdx];
                    }
                    else if (t == myFLen - 1) {
                        // beta(t, u) = beta(t, u+1) * y(t, u)
                        myBeta[t*maxGLen + u] = myBeta[t*maxGLen + u + 1]
                                                + myX[(t*myStrideT + u) * dictSize + myLabel[u]];
                    }
                    else {
                        // beta(t, u) = beta(t+1, u)*null(t, u) + beta(t, u+1)*y(t, u)
                        acc_t current = myBeta[(t+1)*maxGLen + u]
                                        + myX[(t*myStrideT + u) * dictSize + blankIdx];
                        acc_t next = myBeta[t*maxGLen + u + 1]
                                     + myX[(t*myStrideT + u) * dictSize + myLabel[u]];
                        myBeta[t*maxGLen + u] = logSumExp(next, current);
                    }
                }
            }
            __syncthreads();
        }
        if (tid == 0)
            loss[batch] = -myBeta[0];
    }
}
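For reference, the two wavefront loops above implement the transducer forward/backward equations of [1], written here in log space with $\varnothing(t,u)$ for the log-probability of the blank output and $y(t,u)$ for the log-probability of the next label at grid cell $(t,u)$ (notation chosen to match the in-code comments, not taken verbatim from the paper):

$$\alpha(t,u)=\operatorname{logSumExp}\big(\alpha(t-1,u)+\varnothing(t-1,u),\;\alpha(t,u-1)+y(t,u-1)\big)$$
$$\beta(t,u)=\operatorname{logSumExp}\big(\beta(t+1,u)+\varnothing(t,u),\;\beta(t,u+1)+y(t,u)\big)$$

with $\alpha(0,0)=0$, $\beta(T{-}1,U{-}1)=\varnothing(T{-}1,U{-}1)$, and the per-sequence loss $-\ln\Pr(y^{*}\mid x)=-\beta(0,0)$, which is exactly the value written to `loss[batch]`.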
// transducer loss function (i.e. forward-backward algorithm) with batch loading optimization.
// Compared to the vanilla version, there are two optimizations:
// 1. load x in batch through loop unrolling to reduce the latency.
// 2. Use registers and shared memory to hold alpha and beta values passed from one step to the next.
// For simplicity, this kernel currently only supports U <= maxThread, which should be the common
// case. For cases where U > maxThread, the vanilla kernel is used as a fallback option.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// Forward (alpha) and backward (beta) path are launched together. Input is assumed to be converted
// into log scale by the preceding log_softmax layer
// Diagonal wavefront advancing usually used in dynamic programming is leveraged here.
// alpha and beta are of acc_t type, as they are essentially accumulators.
// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into
// [B_packed, H].
// Don't-care region (t > audLen) or (u > txtLen) is removed.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t, int batchLdSize>
__global__ void transducer_loss_batch_load_forward(
    const scalar_t *x, const int *label, const int *audLen, const int *txtLen,
    const int64_t *batchOffset, int64_t dictSize, int64_t blankIdx,
    int64_t maxFLen, int64_t maxGLen, bool packedInput,
    acc_t *alpha, acc_t *beta, scalar_t *loss) {

    const int batch = blockIdx.y;
    int u = threadIdx.x;
    const auto myFLen = audLen[batch];
    const auto myGLen = txtLen[batch] + 1;
    const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
                                              : batch * maxFLen * maxGLen;
    const int64_t myStrideT = packedInput ? myGLen : maxGLen;
    const scalar_t *myX = x + myBatchOffset * dictSize;
    scalar_t next[batchLdSize], current[batchLdSize];
    extern __shared__ char smem8[];
    auto smem = reinterpret_cast<acc_t*>(smem8);

    if (blockIdx.x == 0) {
        // alpha path
        acc_t *myAlpha = alpha + batch*maxFLen*maxGLen;
        // two SMEM regions for double buffering read and write data to avoid data race
        acc_t * const sharedAlpha[2] = {smem, smem + maxGLen};
        sharedAlpha[0][u] = 0;
        __syncthreads();

        if (u == 0)
            myAlpha[0] = 0;

        auto myAlphaLabel = (u == 0) ? 0 : label[batch*(maxGLen-1) + u - 1];
        // register used to pass value to the next step for the same thread
        acc_t prvStepAlpha = 0;
        for (int64_t step = 1; step < myFLen + myGLen - 1 + batchLdSize; step += batchLdSize) {
            // Move along the diagonal wavefront to leverage available parallelism
            // Batch loading X through loop unrolling
            #pragma unroll
            for (int i = 0; i < batchLdSize; ++i) {
                if (step + i < myFLen + myGLen - 1) {
                    // index computing
                    int64_t t = step + i - u;
                    int64_t currentId = ((t-1)*myStrideT + u) * dictSize + blankIdx;
                    int64_t nextId = (t*myStrideT + u - 1) * dictSize + myAlphaLabel;
                    // main loading loop
                    if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
                        if (u == 0) {
                            current[i] = myX[currentId];
                        }
                        else if (t == 0) {
                            next[i] = myX[nextId];
                        }
                        else {
                            current[i] = myX[currentId];
                            next[i] = myX[nextId];
                        }
                    }
                }
            }
            // main computing loop
            for (int i = 0; i < batchLdSize; ++i) {
                // swap the pointer for double buffering
                auto sharedAlphaRd = sharedAlpha[(step + i - 1) % 2];
                auto sharedAlphaWr = sharedAlpha[(step + i) % 2];
                if (step + i < myFLen + myGLen - 1) {
                    int64_t t = step + i - u;
                    if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
                        // Eq(16) in [1]
                        if (u == 0)
                            prvStepAlpha = prvStepAlpha + current[i];
                        else if (t == 0)
                            prvStepAlpha = sharedAlphaRd[u-1] + next[i];
                        else
                            prvStepAlpha = logSumExp(prvStepAlpha + current[i],
                                                     sharedAlphaRd[u-1] + next[i]);
                        sharedAlphaWr[u] = prvStepAlpha;
                        myAlpha[t*maxGLen + u] = prvStepAlpha;
                    }
                }
                __syncthreads();
            }
        }
    }
    else if (blockIdx.x == 1) {
        // beta path
        acc_t *myBeta = beta + batch*maxFLen*maxGLen;
        // two SMEM regions for double buffering read and write data to avoid data race
        acc_t * const sharedBeta[2] = {smem, smem + maxGLen};
        sharedBeta[0][u] = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];
        __syncthreads();

        auto myBetaLabel = (u == maxGLen - 1) ? 0 : label[batch*(maxGLen-1) + u];
        // register used to pass value to the next step for the same thread
        acc_t prvStepBeta = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];
        if (u == 0)
            myBeta[(myFLen-1)*maxGLen + myGLen - 1] = prvStepBeta;

        for (int64_t step = 1; step < myFLen + myGLen - 1; step += batchLdSize) {
            // Move along the diagonal wavefront to leverage available parallelism
            // Batch loading X
            #pragma unroll
            for (int i = 0; i < batchLdSize; ++i) {
                if (step + i < myFLen + myGLen - 1) {
                    // index computing
                    int64_t t = myFLen + myGLen - (step + i) - 2 - u;
                    int64_t currentId = (t*myStrideT + u) * dictSize + blankIdx;
                    int64_t nextId = (t*myStrideT + u) * dictSize + myBetaLabel;
                    // main loading loop
                    if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
                        if (u == myGLen - 1) {
                            current[i] = myX[currentId];
                        }
                        else if (t == myFLen - 1) {
                            next[i] = myX[nextId];
                        }
                        else {
                            current[i] = myX[currentId];
                            next[i] = myX[nextId];
                        }
                    }
                }
            }
            // main computing loop
            for (int i = 0; i < batchLdSize; ++i) {
                // swap the pointer for double buffering
                auto sharedBetaRd = sharedBeta[(step + i - 1) % 2];
                auto sharedBetaWr = sharedBeta[(step + i) % 2];
                if (step + i < myFLen + myGLen - 1) {
                    int64_t t = myFLen + myGLen - (step + i) - 2 - u;
                    if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
                        // Eq(18) in [1]
                        if (u == myGLen - 1)
                            prvStepBeta = prvStepBeta + current[i];
                        else if (t == myFLen - 1)
                            prvStepBeta = sharedBetaRd[u+1] + next[i];
                        else
                            prvStepBeta = logSumExp(prvStepBeta + current[i],
                                                    sharedBetaRd[u+1] + next[i]);
                        sharedBetaWr[u] = prvStepBeta;
                        myBeta[t*maxGLen + u] = prvStepBeta;
                    }
                }
                __syncthreads();
            }
        }
        if (u == 0)
            loss[batch] = -prvStepBeta;
    }
}
// Vanilla transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere,
// hence only Eq(20) in [1] is implemented in this kernel.
// Each thread block works on [batch, t, :, :] of data. Each thread works on a specific u at a time
// Since only gradients for the correct token and null token need to be updated, gradients at other
// locations are initialized to 0.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_backward(
    const scalar_t *x, const scalar_t *lossGrad,
    const int *audLen, const int *txtLen, const int *label,
    const acc_t *alpha, const acc_t *beta,
    const int64_t *batchOffset, int64_t dictSize, int64_t blankIdx,
    int64_t maxFLen, int64_t maxGLen, bool packedInput, scalar_t *xGrad) {

    const int tid = threadIdx.x;
    const int t = blockIdx.x;
    const int batch = blockIdx.y;
    const int64_t myFLen = audLen[batch];
    const int64_t myGLen = txtLen[batch] + 1;
    const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
                                              : batch * maxFLen * maxGLen;
    const int64_t myStrideT = packedInput ? myGLen : maxGLen;
    auto myX = x + (myBatchOffset + t*myStrideT) * dictSize;
    auto myAlpha = alpha + batch*maxFLen*maxGLen;
    auto myBeta = beta + batch*maxFLen*maxGLen;
    auto myXGrad = xGrad + (myBatchOffset + t*myStrideT) * dictSize;
    auto myLabel = label + batch*(maxGLen-1);

    int64_t u = tid;
    while (t < myFLen and u < myGLen) {
        // Do the update
        // loss = -ln(Pr(y*|x))
        acc_t grad = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
        if (u != myGLen - 1)
            myXGrad[u*dictSize + myLabel[u]] = -std::exp(grad + myBeta[t*maxGLen + u + 1]
                                                         + myX[u*dictSize + myLabel[u]]);
        if (t == myFLen - 1 and u == myGLen - 1)
            myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myX[u*dictSize + blankIdx]);
        else if (t != myFLen - 1)
            myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myBeta[(t+1)*maxGLen + u]
                                                       + myX[u*dictSize + blankIdx]);
        u += blockDim.x;
    }
}
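Written out, the update performed above (Eq(20) of [1] restated in log space for the log-softmax outputs $x_{t,u,k}$, with $g_b$ denoting the incoming scalar loss gradient for the sequence) is

$$\frac{\partial L}{\partial x_{t,u,k}} \;=\; -\exp\!\big(\ln g_b+\alpha(t,u)-\beta(0,0)+\beta_k(t,u)+x_{t,u,k}\big),$$

where $\beta_k(t,u)=\beta(t,u{+}1)$ when $k$ is the next label, $\beta_k(t,u)=\beta(t{+}1,u)$ when $k$ is the blank symbol and $t<T{-}1$, and $\beta_k(t,u)=0$ for the blank at the terminal cell $(T{-}1,U{-}1)$; all other entries of the gradient stay at their zero initialization.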
// Fused transudcer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused in this kernel.
// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_fused_backward(
    const scalar_t* x,
    const scalar_t* lossGrad,
    const int* audLen,
    const int* txtLen,
    const int* label,
    const acc_t* alpha,
    const acc_t* beta,
    const int64_t* batchOffset,
    int64_t dictSize,
    int64_t blankIdx,
    int64_t maxFLen,
    int64_t maxGLen,
    bool packedInput,
    scalar_t* xGrad) {

    const int tid = threadIdx.x;
    const int u = blockIdx.x;
    const int t = blockIdx.y;
    const int batch = blockIdx.z;
    const int64_t myFLen = audLen[batch];
    const int64_t myGLen = txtLen[batch] + 1;
    const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1])
                                              : batch * maxFLen * maxGLen;
    const int64_t myStrideT = packedInput ? myGLen : maxGLen;
    __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
    auto myXGrad = xGrad + (myBatchOffset + t * myStrideT + u) * dictSize;

    if (t < myFLen and u < myGLen) {
        auto myX = x + (myBatchOffset + t * myStrideT + u) * dictSize;
        auto myAlpha = alpha + batch * maxFLen * maxGLen;
        auto myBeta = beta + batch * maxFLen * maxGLen;
        auto myLabel = label + batch * (maxGLen - 1);

        // load and store shared variables in SMEM
        if (tid == 0) {
            commonFactor = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0];
            myBetaTU = myBeta[t * maxGLen + u];
            myBetaTUp1 = myBeta[t * maxGLen + u + 1];
            myBetaTp1U = myBeta[(t + 1) * maxGLen + u];
            myLabelShared = myLabel[u];
        }

        __syncthreads();

        for (int64_t h = tid; h < dictSize; h += blockDim.x) {
            // Do the update
            acc_t grad = commonFactor + myX[h];
            // loss = -ln(Pr(y*|x))
            acc_t myGrad = std::exp(grad + myBetaTU);
            if (u != myGLen - 1 and h == myLabelShared) {
                myGrad -= std::exp(grad + myBetaTUp1);
            }
            else if (h == blankIdx) {
                if (t == myFLen - 1 and u == myGLen - 1)
                    myGrad -= std::exp(grad);
                else if (t != myFLen - 1)
                    myGrad -= std::exp(grad + myBetaTp1U);
            }
            myXGrad[h] = myGrad;
        }
    }
    else if (!packedInput) {
        // In non-pack mode, need to make sure the gradients for don't-care regions are zero.
        for (int64_t h = tid; h < dictSize; h += blockDim.x) {
            myXGrad[h] = 0;
        }
    }
}
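Because the softmax backward is folded in, every dictionary entry h receives a dense gradient instead of the sparse scatter used by the vanilla kernel. Under the same assumptions as the sketch above, the value written to xGrad[h] corresponds to the gradient with respect to the pre-softmax logit a(t,u,h); this is an editorial derivation sketch, not text from the original file:

\frac{\partial \mathcal{L}}{\partial a(t,u,h)}
  = g \, \frac{\alpha(t,u)\,\Pr(h \mid t,u)}{\Pr(\mathbf{y}^* \mid \mathbf{x})}
    \Bigl( \beta(t,u)
         - \beta(t,u+1)\,\mathbb{1}[h = y_{u+1},\ u < U]
         - \beta(t+1,u)\,\mathbb{1}[h = \varnothing,\ t < T-1]
         - \mathbb{1}[h = \varnothing,\ t = T-1,\ u = U] \Bigr),

where the recursion \beta(t,u) = \Pr(y_{u+1} \mid t,u)\,\beta(t,u+1) + \Pr(\varnothing \mid t,u)\,\beta(t+1,u) is what absorbs the softmax Jacobian into the leading \beta(t,u) term.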
// Vectorized version of the fused transducer loss backward operation.
// Details of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused into this kernel.
// Each thread block works on [batch, t, u, :] of the data. Each thread works on a specific h at a time.
// To support packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t, typename vec_t, int V>
__global__ void transducer_loss_fused_vec_backward(
    const scalar_t* x,
    const scalar_t* lossGrad,
    const int* audLen,
    const int* txtLen,
    const int* label,
    const acc_t* alpha,
    const acc_t* beta,
    const int64_t* batchOffset,
    int64_t dictSize,
    int64_t blankIdx,
    int64_t maxFLen,
    int64_t maxGLen,
    bool packedInput,
    scalar_t* xGrad) {

    const int tid = threadIdx.x;
    const int u = blockIdx.x;
    const int t = blockIdx.y;
    const int batch = blockIdx.z;
    const int64_t myFLen = audLen[batch];
    const int64_t myGLen = txtLen[batch] + 1;
    const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1])
                                              : batch * maxFLen * maxGLen;
    const int64_t myStrideT = packedInput ? myGLen : maxGLen;
    __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
    auto myXGrad = xGrad + (myBatchOffset + t * myStrideT + u) * dictSize;
    auto myX = x + (myBatchOffset + t * myStrideT + u) * dictSize;
    auto myAlpha = alpha + batch * maxFLen * maxGLen;
    auto myBeta = beta + batch * maxFLen * maxGLen;
    auto myLabel = label + batch * (maxGLen - 1);

    // Variables for vectorization
    scalar_t myXBuffer[V], myXGradBuffer[V];
    auto myXVec = reinterpret_cast<vec_t const *>(myX);
    auto myXGradVec = reinterpret_cast<vec_t*>(myXGrad);
    auto myXBufferVec = reinterpret_cast<vec_t*>(myXBuffer);
    auto myXGradBufferVec = reinterpret_cast<vec_t*>(myXGradBuffer);

    if (t < myFLen and u < myGLen) {
        // load and store shared variables in SMEM
        if (tid == 0) {
            commonFactor = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0];
            myBetaTU = myBeta[t * maxGLen + u];
            if (t != myFLen - 1)
                myBetaTp1U = myBeta[(t + 1) * maxGLen + u];
            if (u != myGLen - 1) {
                myBetaTUp1 = myBeta[t * maxGLen + u + 1];
                myLabelShared = myLabel[u];
            }
        }

        __syncthreads();

        #pragma unroll
        for (int64_t h0 = tid * V; h0 < dictSize; h0 += blockDim.x * V) {
            // Load myX in a vector form
            *myXBufferVec = myXVec[h0 / V];
            // Do the update for a vector of input
            #pragma unroll
            for (int i = 0; i < V; ++i) {
                auto h = h0 + i;
                acc_t grad = commonFactor + myXBuffer[i];
                // loss = -ln(Pr(y*|x))
                acc_t myGrad = std::exp(grad + myBetaTU);
                if (u != myGLen - 1 and h == myLabelShared) {
                    myGrad -= std::exp(grad + myBetaTUp1);
                }
                else if (h == blankIdx) {
                    if (t == myFLen - 1 and u == myGLen - 1)
                        myGrad -= std::exp(grad);
                    else if (t != myFLen - 1)
                        myGrad -= std::exp(grad + myBetaTp1U);
                }
                myXGradBuffer[i] = myGrad;
            }
            // Store myXGrad in a vector form
            myXGradVec[h0 / V] = *myXGradBufferVec;
        }
    }
    else if (!packedInput) {
        // In non-pack mode, need to make sure the gradients for don't-care regions are zero.
        for (int64_t h0 = tid * V; h0 < dictSize; h0 += blockDim.x * V) {
            myXGradVec[h0 / V] = 0;
        }
    }
}
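The vectorization above relies only on reinterpret_cast plus the alignment and divisibility checks made by the host code further down. Below is a standalone illustration of the same access pattern, written for this note: scale_half_vec4 and its parameters are made up for the example and are not part of apex, and it assumes n is a multiple of 4 and both pointers are 8-byte aligned.

#include <cuda_fp16.h>
#include <cstdint>

// Each uint64_t load/store moves V = sizeof(uint64_t)/sizeof(__half) = 4 halves,
// so global memory is touched with 8-byte transactions instead of 2-byte ones.
__global__ void scale_half_vec4(const __half* in, __half* out, float s, int64_t n) {
    constexpr int V = sizeof(uint64_t) / sizeof(__half);  // 4
    auto inVec  = reinterpret_cast<const uint64_t*>(in);
    auto outVec = reinterpret_cast<uint64_t*>(out);
    __half buf[V];
    auto bufVec = reinterpret_cast<uint64_t*>(buf);
    for (int64_t i0 = (int64_t(blockIdx.x) * blockDim.x + threadIdx.x) * V;
         i0 < n;
         i0 += int64_t(gridDim.x) * blockDim.x * V) {
        *bufVec = inVec[i0 / V];             // one 8-byte load covers 4 elements
        #pragma unroll
        for (int i = 0; i < V; ++i)
            buf[i] = __float2half(__half2float(buf[i]) * s);
        outVec[i0 / V] = *bufVec;            // one 8-byte store covers 4 elements
    }
}

Whether the trick applies at all is exactly what the dictSize % vectFactor and memAlign checks in transducer_loss_cuda_backward below guard against.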
std::vector<torch::Tensor> transducer_loss_cuda_forward(
    torch::Tensor x,
    torch::Tensor label,
    torch::Tensor audLen,
    torch::Tensor txtLen,
    torch::Tensor batchOffset,
    int maxFLen,
    int blankIdx,
    int opt,
    bool packedInput) {

    auto scalarType = x.scalar_type();
    auto tensorOpt = x.options();
    const int batchSize = label.size(0);
    const int maxGLen = label.size(1) + 1;
    const int dictSize = x.size(-1);

    TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize,
                "Expected blank index to be in the range of 0 to ", dictSize - 1,
                ", but got ", blankIdx);
    TORCH_CHECK(opt == -1 or opt == 0 or opt == 1,
                "Got an invalid optimization level ", opt);

    // The data type of alpha and beta will be resolved at dispatch time,
    // hence they are defined here and assigned later.
    torch::Tensor alpha;
    torch::Tensor beta;
    torch::Tensor loss = torch::empty({batchSize}, tensorOpt);

    const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
    const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
    const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock;
    const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr<int64_t>() : nullptr;

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(scalarType, "transducer_loss_cuda_forward", ([&] {
        // resolve accumulation type
        using acc_t = at::acc_type<scalar_t, true>;
        auto accType = c10::CppTypeToScalarType<acc_t>::value;
        auto accTensorOpt = tensorOpt.dtype(accType);
        alpha = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);
        beta = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);

        // Decide which kernel to launch based on the problem size.
        // If the required SMEM size or the number of threads exceeds the limit, fall back to the
        // vanilla kernel.
        const auto smemSize = 2 * maxGLen * sizeof(acc_t);
        const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0
                                 : (opt == -1) ? 1 : opt;
        const int threads = std::min(maxThreadPerBlock, maxGLen);
        const dim3 blocks(2, batchSize, 1);

        if (optFallBack == 0)
            transducer_loss_forward<<<blocks, threads, 0, stream>>>(
                x.data_ptr<scalar_t>(),
                label.data_ptr<int>(),
                audLen.data_ptr<int>(),
                txtLen.data_ptr<int>(),
                batchOffsetPtr,
                dictSize,
                blankIdx,
                maxFLen,
                maxGLen,
                packedInput,
                alpha.data_ptr<acc_t>(),
                beta.data_ptr<acc_t>(),
                loss.data_ptr<scalar_t>());
        else if (optFallBack == 1)
            transducer_loss_batch_load_forward<scalar_t, acc_t, 4>
                <<<blocks, threads, smemSize, stream>>>(
                x.data_ptr<scalar_t>(),
                label.data_ptr<int>(),
                audLen.data_ptr<int>(),
                txtLen.data_ptr<int>(),
                batchOffsetPtr,
                dictSize,
                blankIdx,
                maxFLen,
                maxGLen,
                packedInput,
                alpha.data_ptr<acc_t>(),
                beta.data_ptr<acc_t>(),
                loss.data_ptr<scalar_t>());
    }));
    C10_CUDA_CHECK(cudaGetLastError());

    return {alpha, beta, loss};
}
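To make the fallback decision concrete with assumed numbers (not taken from the source): suppose acc_t is float and maxGLen = 512 on a device with the common limits of 1024 threads and 48 KB of static shared memory per block. Then

\text{smemSize} = 2 \times 512 \times 4\ \text{B} = 4096\ \text{B} \ll 48\ \text{KB},
\qquad
\text{maxGLen} = 512 \le 1024,

so neither limit is hit and opt = -1 resolves to optFallBack = 1, i.e. the batch-load kernel launched with smemSize bytes of dynamic shared memory. A hypothetical maxGLen of 2048 would exceed the thread limit and force optFallBack = 0, the vanilla kernel.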
torch::Tensor transducer_loss_cuda_backward(
    torch::Tensor x,
    torch::Tensor lossGrad,
    torch::Tensor alpha,
    torch::Tensor beta,
    torch::Tensor audLen,
    torch::Tensor txtLen,
    torch::Tensor label,
    torch::Tensor batchOffset,
    int maxFLen,
    int blankIdx,
    int opt,
    bool fuseSoftmaxBackward,
    bool packedInput) {

    auto dtype = x.scalar_type();
    torch::Tensor xGrad;
    const int batchSize = label.size(0);
    const int maxGLen = label.size(1) + 1;
    const int dictSize = x.size(-1);
    const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
    const int maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
    const int warpSize = deviceProperties->warpSize;
    const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr<int64_t>() : nullptr;

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (fuseSoftmaxBackward) {
        // Allocate an uninitialized tensor for performance; the kernel therefore needs to ensure
        // zeros are written to the don't-care regions.
        xGrad = torch::empty_like(x);

        // Would like each thread to work on 4 hidden units
        const int workPerThread = 4;
        // Don't want to have more than 128 threads per thread block
        const int maxThreadPerElmt = std::min(128, maxThreadPerBlock);
        const int threads = std::min(maxThreadPerElmt,
                                     std::max(warpSize, (dictSize + workPerThread - 1) / workPerThread));
        const dim3 blocks(maxGLen, maxFLen, batchSize);

        AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
            using vec_t = uint64_t;
            using acc_t = at::acc_type<scalar_t, true>;
            constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);
            constexpr int vecAlignment = std::alignment_of<vec_t>::value;
            // Check whether all input and output tensors meet the alignment requirement.
            bool memAlign = reinterpret_cast<uint64_t>(x.data_ptr<scalar_t>()) % vecAlignment == 0
                            and reinterpret_cast<uint64_t>(xGrad.data_ptr<scalar_t>())
                                    % vecAlignment == 0;

            if (vectFactor > 1 and dictSize % vectFactor == 0 and memAlign) {
                transducer_loss_fused_vec_backward<scalar_t, acc_t, vec_t, vectFactor>
                    <<<blocks, threads, 0, stream>>>(
                    x.data_ptr<scalar_t>(),
                    lossGrad.data_ptr<scalar_t>(),
                    audLen.data_ptr<int>(),
                    txtLen.data_ptr<int>(),
                    label.data_ptr<int>(),
                    alpha.data_ptr<acc_t>(),
                    beta.data_ptr<acc_t>(),
                    batchOffsetPtr,
                    dictSize,
                    blankIdx,
                    maxFLen,
                    maxGLen,
                    packedInput,
                    xGrad.data_ptr<scalar_t>());
            }
            else {
                transducer_loss_fused_backward<<<blocks, threads, 0, stream>>>(
                    x.data_ptr<scalar_t>(),
                    lossGrad.data_ptr<scalar_t>(),
                    audLen.data_ptr<int>(),
                    txtLen.data_ptr<int>(),
                    label.data_ptr<int>(),
                    alpha.data_ptr<acc_t>(),
                    beta.data_ptr<acc_t>(),
                    batchOffsetPtr,
                    dictSize,
                    blankIdx,
                    maxFLen,
                    maxGLen,
                    packedInput,
                    xGrad.data_ptr<scalar_t>());
            }
        }));
    }
    else {
        // For the non-fused kernel, the gradients that need to be written are very sparse, hence
        // initialize the tensor with all zeros.
        xGrad = torch::zeros_like(x);
        // Don't launch more threads than needed.
        const int threads = std::min(maxThreadPerBlock, maxGLen);
        const dim3 blocks(maxFLen, batchSize);
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
            using acc_t = at::acc_type<scalar_t, true>;
            transducer_loss_backward<<<blocks, threads, 0, stream>>>(
                x.data_ptr<scalar_t>(),
                lossGrad.data_ptr<scalar_t>(),
                audLen.data_ptr<int>(),
                txtLen.data_ptr<int>(),
                label.data_ptr<int>(),
                alpha.data_ptr<acc_t>(),
                beta.data_ptr<acc_t>(),
                batchOffsetPtr,
                dictSize,
                blankIdx,
                maxFLen,
                maxGLen,
                packedInput,
                xGrad.data_ptr<scalar_t>());
        }));
    }
    C10_CUDA_CHECK(cudaGetLastError());
    return xGrad;
}
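A quick check on the fused-path launch configuration, with dictSize = 1000 assumed purely for illustration and warpSize = 32:

\text{threads} = \min\bigl(128,\ \max(32,\ \lceil 1000 / 4 \rceil)\bigr) = \min(128,\ 250) = 128,

and the grid is (maxGLen, maxFLen, batchSize), i.e. one 128-thread block per (t, u, batch) cell. Each block then sweeps all 1000 dictionary entries in strides of blockDim.x, roughly 8 scalar entries per thread, or at most 2 vector iterations per thread when the uint64_t path is taken.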
apex/contrib/csrc/xentropy/interface.cpp
deleted
100644 → 0
View file @
2a4864d5
#include <torch/extension.h>

// CUDA forward declarations

std::vector<at::Tensor> softmax_xentropy_cuda(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing,
    const bool half_to_float);

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    const at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing);

// C++ interface

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> softmax_xentropy_forward(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing,
    const bool half_to_float) {
    CHECK_CUDA(input);
    CHECK_INPUT(labels);

    return softmax_xentropy_cuda(input, labels, smoothing, half_to_float);
}

at::Tensor softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    const at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing) {
    CHECK_CUDA(grad_loss);
    CHECK_CUDA(logits);
    CHECK_INPUT(max_log_sum_exp);
    CHECK_INPUT(labels);

    return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &softmax_xentropy_forward,
          "Softmax cross entropy loss with label smoothing forward (CUDA)");
    m.def("backward", &softmax_xentropy_backward,
          "Softmax cross entropy loss with label smoothing backward (CUDA)");
}
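One note on the CHECK_* macros above: x.type().is_cuda() goes through the long-deprecated Tensor::type() accessor. A minimal, hypothetical modernization (an editorial sketch, not part of the original file) would express the same guards with TORCH_CHECK and Tensor::is_cuda():

// Equivalent input checks against current PyTorch (editorial sketch, not in the original file).
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)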