OpenDAS / apex · Commit bdd481d1 (Unverified)
Authored May 21, 2020 by Jeff Daily; committed by GitHub, May 21, 2020

pass all TensorListMetadata as pointer to pinned host memory (#13)
Parent: b2b55439
Showing 13 changed files with 126 additions and 121 deletions (+126 −121)
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu   +15 −15
apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu   +15 −15
csrc/multi_tensor_adagrad.cu                             +7 −7
csrc/multi_tensor_adam.cu                                +9 −9
csrc/multi_tensor_apply.cuh                              +7 −2
csrc/multi_tensor_axpby_kernel.cu                        +7 −7
csrc/multi_tensor_l2norm_kernel.cu                       +12 −12
csrc/multi_tensor_lamb.cu                                +15 −15
csrc/multi_tensor_lamb_stage_1.cu                        +10 −10
csrc/multi_tensor_lamb_stage_2.cu                        +7 −7
csrc/multi_tensor_novograd.cu                            +8 −8
csrc/multi_tensor_scale_kernel.cu                        +6 −6
csrc/multi_tensor_sgd_kernel.cu                          +8 −8
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu

@@ -76,7 +76,7 @@ struct AdamFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<DEPTH>& tl,
+    TensorListMetadata<DEPTH>* tl,
     const float b1,
     const float b2,
     const float eps,

@@ -85,21 +85,21 @@ struct AdamFunctor
     adamMode_t mode,
     const float decay)
   {
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
-    T* p = (T *)tl.addresses[0][tensor_loc];
+    T* p = (T *)tl->addresses[0][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* m = (T *)tl.addresses[1][tensor_loc];
+    T* m = (T *)tl->addresses[1][tensor_loc];
     m += chunk_idx*chunk_size;
-    T* v = (T *)tl.addresses[2][tensor_loc];
+    T* v = (T *)tl->addresses[2][tensor_loc];
     v += chunk_idx*chunk_size;
-    GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
+    GRAD_T* g = (GRAD_T *)tl->addresses[3][tensor_loc];
     g += chunk_idx*chunk_size;
     GRAD_T* p_copy = NULL;
     if (DEPTH == 5) {
-      p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
+      p_copy = (GRAD_T *)tl->addresses[4][tensor_loc];
       p_copy += chunk_idx*chunk_size;
     }

@@ -736,17 +736,17 @@ struct MaybeCastFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* overflow_flag,
-    TensorListMetadata<DEPTH>& tl)
+    TensorListMetadata<DEPTH>* tl)
   {
     if (overflow_flag && *overflow_flag != 0) return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
-    FROM_T* p_in = (FROM_T *)tl.addresses[0][tensor_loc];
+    FROM_T* p_in = (FROM_T *)tl->addresses[0][tensor_loc];
     p_in += chunk_idx*chunk_size;
-    TO_T* p_out = (TO_T *)tl.addresses[1][tensor_loc];
+    TO_T* p_out = (TO_T *)tl->addresses[1][tensor_loc];
     p_out += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
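The same mechanical edit repeats in every functor in this commit: once TensorListMetadata arrives as a pointer into pinned host memory instead of a by-reference kernel argument, every tl.member access becomes tl->member. Below is a minimal, standalone sketch of the device side of that pattern; it is plain CUDA with illustrative names (Meta, read_meta), not apex code, and it assumes a GPU that can map pinned host memory into its address space.

// Sketch only: a tiny stand-in for TensorListMetadata, read by the kernel
// through a pointer to pinned host memory rather than received by value.
#include <cstdio>
#include <cuda_runtime.h>

struct Meta {
  int block_to_tensor[4];
  int sizes[4];
};

__global__ void read_meta(const Meta* m)
{
  // Pointer parameter, so members are reached with '->' (the tl. to tl-> change).
  int t = m->block_to_tensor[blockIdx.x];
  printf("block %d -> tensor %d (size %d)\n", blockIdx.x, t, m->sizes[t]);
}

int main()
{
  Meta* h_meta = nullptr;
  // Pinned + mapped host allocation: the device can dereference it directly.
  cudaHostAlloc(reinterpret_cast<void**>(&h_meta), sizeof(Meta), cudaHostAllocMapped);
  for (int i = 0; i < 4; ++i) {
    h_meta->block_to_tensor[i] = i;
    h_meta->sizes[i] = 128 * (i + 1);
  }

  Meta* d_view = nullptr;  // device-visible alias of the same pinned buffer
  cudaHostGetDevicePointer(reinterpret_cast<void**>(&d_view), h_meta, 0);

  read_meta<<<4, 1>>>(d_view);
  cudaDeviceSynchronize();
  cudaFreeHost(h_meta);
  return 0;
}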
apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu

@@ -32,7 +32,7 @@ struct LAMBStage1Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<4>& tl,
+    TensorListMetadata<4>* tl,
     const float beta1,
     const float beta2,
     const float beta3,

@@ -48,22 +48,22 @@ struct LAMBStage1Functor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
     float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
-    T* g = (T*)tl.addresses[0][tensor_loc];
+    T* g = (T*)tl->addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
-    T* p = (T*)tl.addresses[1][tensor_loc];
+    T* p = (T*)tl->addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* m = (T*)tl.addresses[2][tensor_loc];
+    T* m = (T*)tl->addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;
-    T* v = (T*)tl.addresses[3][tensor_loc];
+    T* v = (T*)tl->addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;

@@ -147,7 +147,7 @@ struct LAMBStage2Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<2>& tl,
+    TensorListMetadata<2>* tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
     const float learning_rate,

@@ -157,10 +157,10 @@ struct LAMBStage2Functor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
     MATH_T ratio = learning_rate;
     // apply adaptive learning rate to parameters with non-zero weight decay

@@ -171,10 +171,10 @@ struct LAMBStage2Functor
       ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
     }
-    T* update = (T*)tl.addresses[0][tensor_loc];
+    T* update = (T*)tl->addresses[0][tensor_loc];
     update += chunk_idx*chunk_size;
-    T* p = (T*)tl.addresses[1][tensor_loc];
+    T* p = (T*)tl->addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
csrc/multi_tensor_adagrad.cu

@@ -23,20 +23,20 @@ using MATH_T = float;
 template <typename T> struct AdagradFunctor {
   __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
+  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> *tl,
              const float epsilon, const float lr, adagradMode_t mode,
              const float weight_decay) {
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
-    T *g = (T *)tl.addresses[0][tensor_loc];
+    T *g = (T *)tl->addresses[0][tensor_loc];
     g += chunk_idx * chunk_size;
-    T *p = (T *)tl.addresses[1][tensor_loc];
+    T *p = (T *)tl->addresses[1][tensor_loc];
     p += chunk_idx * chunk_size;
-    T *h = (T *)tl.addresses[2][tensor_loc];
+    T *h = (T *)tl->addresses[2][tensor_loc];
     h += chunk_idx * chunk_size;
     n -= chunk_idx * chunk_size;
csrc/multi_tensor_adam.cu

@@ -26,7 +26,7 @@ struct AdamFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<4>& tl,
+    TensorListMetadata<4>* tl,
     const float beta1,
     const float beta2,
     const float beta1_correction,

@@ -40,24 +40,24 @@ struct AdamFunctor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
     // potentially use to pass in list of scalar
-    // int tensor_num = tl.start_tensor_this_launch + tensor_loc;
+    // int tensor_num = tl->start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
-    T* g = (T*)tl.addresses[0][tensor_loc];
+    T* g = (T*)tl->addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
-    T* p = (T*)tl.addresses[1][tensor_loc];
+    T* p = (T*)tl->addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* m = (T*)tl.addresses[2][tensor_loc];
+    T* m = (T*)tl->addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;
-    T* v = (T*)tl.addresses[3][tensor_loc];
+    T* v = (T*)tl->addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
csrc/multi_tensor_apply.cuh

@@ -2,6 +2,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
+#include <THC/THC.h>
 #include "compat.h"
 #include <assert.h>

@@ -29,7 +30,7 @@ template<typename T, typename U, typename... ArgTypes>
 __global__ void multi_tensor_apply_kernel(
     int chunk_size,
     volatile int* noop_flag,
-    T tl,
+    T* tl,
     U callable,
     ArgTypes... args)
 {

@@ -104,11 +105,15 @@ void multi_tensor_apply(
       bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
       if (tensors_full || blocks_full || last_chunk)
       {
+        auto storage = at::empty(sizeof(tl), c10::TensorOptions(at::kStrided).dtype(at::kByte).device(at::kCPU).pinned_memory(true));
+        auto tl_as_host_pinned_ptr = static_cast<decltype(tl)*>(storage.data_ptr());
+        memcpy(tl_as_host_pinned_ptr, &tl, sizeof(tl));
+        AT_CUDA_CHECK(THCCachingHostAllocator_recordEvent(tl_as_host_pinned_ptr, stream));
         // using accscalar_t = acc_type<scalar_t, true>;
         multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
           chunk_size,
           noop_flag.DATA_PTR<int>(),
-          tl,
+          tl_as_host_pinned_ptr,
           callable,
           args...);
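This hunk is the heart of the commit: instead of passing the whole TensorListMetadata struct by value as a kernel argument, the host copies it into a pinned byte buffer obtained from the ATen allocator and launches multi_tensor_apply_kernel with just the pointer. A simplified, self-contained sketch of that staging pattern follows; it is plain CUDA with illustrative names (BigMeta, consume, launch), and cudaHostAlloc plus a stream synchronize stand in for the ATen pinned allocator and the caching-host-allocator event recorded by the real code.

#include <cstring>
#include <cuda_runtime.h>

// Stand-in for TensorListMetadata: large enough that passing it by value
// in kernel-argument space would be costly or hit the argument-size limit.
struct BigMeta {
  void* addresses[4][64];
  int   sizes[64];
  int   block_to_tensor[320];
  int   block_to_chunk[320];
};

__global__ void consume(const BigMeta* meta, int chunk_size)
{
  // Real functors index meta->addresses[...] here; this sketch only reads.
  int tensor = meta->block_to_tensor[blockIdx.x];
  int chunk  = meta->block_to_chunk[blockIdx.x];
  (void)tensor; (void)chunk; (void)chunk_size;
}

void launch(const BigMeta& tl, int chunk_size, int nblocks, cudaStream_t stream)
{
  BigMeta* staging = nullptr;
  // Pinned, mapped host memory: the kernel argument shrinks to one pointer
  // and the device reads the struct directly out of host memory.
  cudaHostAlloc(reinterpret_cast<void**>(&staging), sizeof(BigMeta), cudaHostAllocMapped);
  std::memcpy(staging, &tl, sizeof(BigMeta));

  BigMeta* dev_view = nullptr;
  cudaHostGetDevicePointer(reinterpret_cast<void**>(&dev_view), staging, 0);
  consume<<<nblocks, 64, 0, stream>>>(dev_view, chunk_size);

  // The real code records an event with the caching host allocator so the
  // buffer outlives the kernel; here we simply synchronize before freeing.
  cudaStreamSynchronize(stream);
  cudaFreeHost(staging);
}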
csrc/multi_tensor_axpby_kernel.cu

@@ -30,7 +30,7 @@ struct AxpbyFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<3>& tl,
+    TensorListMetadata<3>* tl,
     float a,
     float b,
     int arg_to_check)

@@ -39,17 +39,17 @@ struct AxpbyFunctor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
-    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
+    x_t* x = (x_t*)tl->addresses[0][tensor_loc];
     x += chunk_idx*chunk_size;
-    y_t* y = (y_t*)tl.addresses[1][tensor_loc];
+    y_t* y = (y_t*)tl->addresses[1][tensor_loc];
     y += chunk_idx*chunk_size;
-    out_t* out = (out_t*)tl.addresses[2][tensor_loc];
+    out_t* out = (out_t*)tl->addresses[2][tensor_loc];
     out += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
csrc/multi_tensor_l2norm_kernel.cu

@@ -30,7 +30,7 @@ struct L2NormFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<1>& tl,
+    TensorListMetadata<1>* tl,
     float* output,
     float* output_per_tensor,
     bool per_tensor,

@@ -40,11 +40,11 @@ struct L2NormFunctor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
-    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
+    x_t* x = (x_t*)tl->addresses[0][tensor_loc];
     x += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;

@@ -103,7 +103,7 @@ struct L2NormFunctor
         *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
       output[blockIdx.x] += final;
       if(per_tensor)
-        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
+        output_per_tensor[(tl->start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
     }
   }
 };

@@ -115,7 +115,7 @@ struct MaxNormFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<1>& tl,
+    TensorListMetadata<1>* tl,
     float* output,
     float* output_per_tensor,
     bool per_tensor,

@@ -125,11 +125,11 @@ struct MaxNormFunctor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
-    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
+    x_t* x = (x_t*)tl->addresses[0][tensor_loc];
     x += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;

@@ -188,7 +188,7 @@ struct MaxNormFunctor
         *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
       output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
       if(per_tensor)
-        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
+        output_per_tensor[(tl->start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
     }
   }
 };
csrc/multi_tensor_lamb.cu

@@ -43,7 +43,7 @@ struct LAMBStage1Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<4>& tl,
+    TensorListMetadata<4>* tl,
     const float beta1,
     const float beta2,
     const float beta3,

@@ -59,22 +59,22 @@ struct LAMBStage1Functor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
     float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
-    T* g = (T*)tl.addresses[0][tensor_loc];
+    T* g = (T*)tl->addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
-    T* p = (T*)tl.addresses[1][tensor_loc];
+    T* p = (T*)tl->addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* m = (T*)tl.addresses[2][tensor_loc];
+    T* m = (T*)tl->addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;
-    T* v = (T*)tl.addresses[3][tensor_loc];
+    T* v = (T*)tl->addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;

@@ -236,7 +236,7 @@ struct LAMBStage2Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<2>& tl,
+    TensorListMetadata<2>* tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
     const float learning_rate)

@@ -245,19 +245,19 @@ struct LAMBStage2Functor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
     float param_norm = per_tensor_param_norm[tensor_num];
     float update_norm = per_tensor_update_norm[tensor_num];
     MATH_T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
-    T* update = (T*)tl.addresses[0][tensor_loc];
+    T* update = (T*)tl->addresses[0][tensor_loc];
     update += chunk_idx*chunk_size;
-    T* p = (T*)tl.addresses[1][tensor_loc];
+    T* p = (T*)tl->addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
csrc/multi_tensor_lamb_stage_1.cu

@@ -20,7 +20,7 @@ struct LAMBStage1Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<5>& tl,
+    TensorListMetadata<5>* tl,
     const float* per_tensor_decay,
     const float beta1,
     const float beta2,

@@ -33,26 +33,26 @@ struct LAMBStage1Functor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
     float decay = per_tensor_decay[tensor_num];
-    GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc];
+    GRAD_T* g = (GRAD_T*)tl->addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
-    T* p = (T*)tl.addresses[1][tensor_loc];
+    T* p = (T*)tl->addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* m = (T*)tl.addresses[2][tensor_loc];
+    T* m = (T*)tl->addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;
-    T* v = (T*)tl.addresses[3][tensor_loc];
+    T* v = (T*)tl->addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;
-    UPD_T* update = (UPD_T*)tl.addresses[4][tensor_loc];
+    UPD_T* update = (UPD_T*)tl->addresses[4][tensor_loc];
     update += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
csrc/multi_tensor_lamb_stage_2.cu

@@ -21,7 +21,7 @@ struct LAMBStage2Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<2>& tl,
+    TensorListMetadata<2>* tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
     const float learning_rate)

@@ -30,19 +30,19 @@ struct LAMBStage2Functor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
     float param_norm = per_tensor_param_norm[tensor_num];
     float update_norm = per_tensor_update_norm[tensor_num];
     T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
-    T* p = (T*)tl.addresses[0][tensor_loc];
+    T* p = (T*)tl->addresses[0][tensor_loc];
     p += chunk_idx*chunk_size;
-    UPD_T* update = (UPD_T*)tl.addresses[1][tensor_loc];
+    UPD_T* update = (UPD_T*)tl->addresses[1][tensor_loc];
     update += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
csrc/multi_tensor_novograd.cu

@@ -35,7 +35,7 @@ struct NovoGradFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<3>& tl,
+    TensorListMetadata<3>* tl,
     const float beta1,
     const float beta2,
     const float beta3,

@@ -51,20 +51,20 @@ struct NovoGradFunctor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
     float grad_norm = per_tensor_grad_norm[tensor_num];
-    T* g = (T*)tl.addresses[0][tensor_loc];
+    T* g = (T*)tl->addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
-    T* p = (T*)tl.addresses[1][tensor_loc];
+    T* p = (T*)tl->addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* m = (T*)tl.addresses[2][tensor_loc];
+    T* m = (T*)tl->addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
csrc/multi_tensor_scale_kernel.cu

@@ -32,21 +32,21 @@ struct ScaleFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<2>& tl,
+    TensorListMetadata<2>* tl,
     float scale)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
-    in_t* in = (in_t*)tl.addresses[0][tensor_loc];
+    in_t* in = (in_t*)tl->addresses[0][tensor_loc];
     in += chunk_idx*chunk_size;
-    out_t* out = (out_t*)tl.addresses[1][tensor_loc];
+    out_t* out = (out_t*)tl->addresses[1][tensor_loc];
     out += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
csrc/multi_tensor_sgd_kernel.cu

@@ -32,7 +32,7 @@ struct SGDFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<N>& tl,
+    TensorListMetadata<N>* tl,
     float wd,
     float momentum,
     float dampening,

@@ -45,23 +45,23 @@ struct SGDFunctor
     // Early exit if we don't need to do anything
     if (*noop_gmem) return;
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];
-    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
+    T_grad* grad_in = (T_grad*)tl->addresses[0][tensor_loc];
     grad_in += chunk_idx*chunk_size;
-    T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
+    T_weight* weight_in = (T_weight*)tl->addresses[1][tensor_loc];
     weight_in += chunk_idx*chunk_size;
-    T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
+    T_weight* mom_in = (T_weight*)tl->addresses[2][tensor_loc];
     mom_in += chunk_idx*chunk_size;
     at::Half *model_weights_out = nullptr;
     if (N == 4)
     {
-      model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
+      model_weights_out = (at::Half*)tl->addresses[3][tensor_loc];
       model_weights_out += chunk_idx*chunk_size;
     }