OpenDAS / Megatron-LM / Commits

Commit b1a83375, authored Mar 16, 2021 by Vijay Korthikanti
parent 1a2cb60c

softmax data load/store optimization
Showing 2 changed files with 161 additions and 77 deletions (+161 -77):

    megatron/fused_kernels/scaled_masked_softmax.h                +80 -42
    megatron/fused_kernels/scaled_upper_triang_masked_softmax.h   +81 -35
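The diff below replaces per-element global-memory traffic in the two fused softmax kernels with 4-element vectorized loads and stores (copy_vector, plus copy_zero_vector in the upper-triangular kernel), so a __half kernel issues one 64-bit transaction where it previously issued four 16-bit ones. A self-contained sketch of the idiom, before and after (hypothetical kernel names; assumes 8-byte-aligned buffers and n divisible by 4; not the committed code):

    #include <cuda_fp16.h>

    // Scalar path: four 2-byte loads and four 2-byte stores per group.
    __global__ void scale_scalar(__half *dst, const __half *src, float scale, int n) {
        int i = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
        if (i + 3 < n) {
            for (int e = 0; e < 4; ++e)
                dst[i + e] = __float2half(__half2float(src[i + e]) * scale);
        }
    }

    // Vectorized path: reinterpret 4 contiguous __half (8 bytes) as one float2,
    // so the load and the store each become a single 64-bit transaction.
    __global__ void scale_vectorized(__half *dst, const __half *src, float scale, int n) {
        int i = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
        if (i + 3 < n) {
            __half tmp[4];
            *((float2*) tmp) = *((const float2*) (src + i));   // one 8-byte load
            for (int e = 0; e < 4; ++e)                        // math stays scalar, in registers
                tmp[e] = __float2half(__half2float(tmp[e]) * scale);
            *((float2*) (dst + i)) = *((const float2*) tmp);   // one 8-byte store
        }
    }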
megatron/fused_kernels/scaled_masked_softmax.h @ b1a83375
@@ -26,6 +26,23 @@
 namespace {

+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
+
+template <>
+__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }
+
 int log2_ceil(int value) {
     int log2_value = 0;
     while ((1 << log2_value) < value) ++log2_value;
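The 4-wide specializations work by type punning: four __half values occupy 8 bytes (one float2), and four uint8_t mask bytes occupy 4 bytes (one half2), so a single wide access moves the whole group. A quick compile-time check of that size assumption (illustrative only, not part of the commit):

    #include <cuda_fp16.h>
    #include <cuda_runtime.h>
    #include <cstdint>

    static_assert(4 * sizeof(__half)  == sizeof(float2), "4 halves = one float2 (8 bytes)");
    static_assert(4 * sizeof(uint8_t) == sizeof(half2),  "4 mask bytes = one half2 (4 bytes)");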
@@ -90,13 +107,14 @@ __global__ void scaled_masked_softmax_warp_forward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = 4;

     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches)
     int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
     int pad_first_batch = 0;
     if (pad_batches != 1) { // bert style
         pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
     } else { // gpt2 style
         pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
     }
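For concreteness, a worked instance of these compile-time constants (my arithmetic, not part of the commit): with log2_elements = 11, i.e. next_power_of_two = 2048, and C10_WARP_SIZE = 32. Note that for the smallest rows (next_power_of_two < 128) WARP_ITERATIONS falls below ELEMENTS_PER_LDG_STG, so the even tiling shown here holds for next_power_of_two >= 128:

    constexpr int next_power_of_two = 2048;   // 1 << log2_elements, log2_elements = 11
    constexpr int C10_WARP_SIZE = 32;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; // 32
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;  // 64
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;  // 1
    constexpr int ELEMENTS_PER_LDG_STG = 4;
    // 64 % 4 == 0: the "it += ELEMENTS_PER_LDG_STG" loops below tile the row exactly
    static_assert(WARP_ITERATIONS % ELEMENTS_PER_LDG_STG == 0, "vector loop tiles evenly here");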
@@ -110,29 +128,40 @@ __global__ void scaled_masked_softmax_warp_forward(
     // there might be multiple batches per warp. compute the index within the batch
     int local_idx = threadIdx.x;

-    src += first_batch * element_count + local_idx;
-    dst += first_batch * element_count + local_idx;
-    mask += pad_first_batch * element_count + local_idx;
+    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

     // load data from global memory
     acc_t elements[WARP_BATCH][WARP_ITERATIONS];
+    input_t temp_data[ELEMENTS_PER_LDG_STG];
+    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         int batch_element_count = (i >= local_batches) ? 0 : element_count;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+            int itr_idx = i * element_count + it * WARP_SIZE;
             if (element_index < batch_element_count) {
-                int itr_idx = i * element_count + it * WARP_SIZE;
-                if (mask[itr_idx] != 1) {
-                    elements[i][it] = (acc_t)src[itr_idx] * scale;
-                } else {
-                    elements[i][it] = -10000.0;
-                }
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
+                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
+
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    if (temp_mask[element] != 1) {
+                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
+                    } else {
+                        elements[i][it + element] = -10000.0;
+                    }
+                }
             } else {
-                elements[i][it] = -std::numeric_limits<acc_t>::infinity();
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                }
             }
         }
     }
@@ -161,15 +190,20 @@ __global__ void scaled_masked_softmax_warp_forward(
     warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

     // store result
+    output_t out[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         if (i >= local_batches)
             break;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < element_count) {
-                dst[i * element_count + it * WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    out[element] = elements[i][it + element] / sum[i];
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
             } else {
                 break;
             }
         }
@@ -192,6 +226,7 @@ __global__ void scaled_masked_softmax_warp_backward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = 4;

     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches)
@@ -207,36 +242,34 @@ __global__ void scaled_masked_softmax_warp_backward(
     int local_idx = threadIdx.x;

     // the first element to process by the current thread
-    int thread_offset = first_batch * element_count + local_idx;
+    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
     grad += thread_offset;
     output += thread_offset;
     gradInput += thread_offset;

     // load data from global memory
     acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] {0.0f};
-    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
+    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] {0.0f};
+    input_t temp_grad[ELEMENTS_PER_LDG_STG];
+    input_t temp_output[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         int batch_element_count = (i >= local_batches) ? 0 : element_count;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < batch_element_count) {
-                output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
-            } else {
-                output_reg[i][it] = acc_t(0);
-            }
-        }
-        #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
-            if (element_index < batch_element_count) {
-                grad_reg[i][it] = (acc_t)grad[i * element_count + it * WARP_SIZE] * output_reg[i][it];
-            } else {
-                grad_reg[i][it] = acc_t(0);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);
+
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    output_reg[i][it + element] = (acc_t)temp_output[element];
+                }
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
+                }
             }
         }
     }
@@ -257,11 +290,16 @@ __global__ void scaled_masked_softmax_warp_backward(
         if (i >= local_batches)
             break;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < element_count) {
                 // compute gradients
-                gradInput[i * element_count + it * WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
+                output_t out[ELEMENTS_PER_LDG_STG];
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
             }
         }
     }
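The arithmetic being vectorized here is the standard softmax backward. With y = softmax(scale * x) and incoming gradient g, the kernel stages grad_reg = g * y elementwise, output_reg = y, and sum = sum_j grad_reg_j, so each stored value is

    \frac{\partial L}{\partial x_i}
      = \mathrm{scale} \cdot y_i \Bigl( g_i - \sum_j g_j\, y_j \Bigr)
      = \mathrm{scale} \cdot \bigl( \texttt{grad\_reg}_i - \texttt{output\_reg}_i \cdot \texttt{sum} \bigr),
    \qquad \texttt{grad\_reg}_i = g_i\, y_i .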
@@ -299,8 +337,8 @@ void dispatch_scaled_masked_softmax_forward(
     constexpr int threads_per_block = 128;
     int warps_per_block = (threads_per_block / warp_size);
     int batches_per_block = warps_per_block * batches_per_warp;
     TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
     dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
     dim3 threads(warp_size, warps_per_block, 1);
     // Launch code would be more elegant if C++ supported FOR CONSTEXPR
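Again for concreteness, the launch geometry this dispatch produces in one typical case (my numbers; batches_per_warp is set earlier in the dispatch and is assumed to be 1 here, matching WARP_BATCH for next_power_of_two > 128):

    // query_seq_len = 1024, attn_heads = 16, batches = 8, warp_size = 32:
    int  warps_per_block   = 128 / 32;   // = 4 warps per block
    int  batches_per_block = 4 * 1;      // = 4 softmax rows per block
    // 1024 % 4 == 0, so the TORCH_INTERNAL_ASSERT above passes
    dim3 blocks(1024 / 4, 16, 8);        // (256, 16, 8) blocks
    dim3 threads(32, 4, 1);              // 128 threads per block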
@@ -388,7 +426,7 @@ void dispatch_scaled_masked_softmax_backward(
     constexpr int threads_per_block = 128;
     int warps_per_block = (threads_per_block / warp_size);
     int batches_per_block = warps_per_block * batches_per_warp;
     int blocks = batch_count / batches_per_block;
     dim3 threads(warp_size, warps_per_block, 1);
     // Launch code would be more elegant if C++ supported FOR CONSTEXPR
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h @ b1a83375
@@ -26,6 +26,27 @@
 namespace {

+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
+
+template <>
+__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }
+
+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_zero_vector(Datatype *dst);
+
+template <>
+__device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }
+
 int log2_ceil(int value) {
     int log2_value = 0;
     while ((1 << log2_value) < value) ++log2_value;
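copy_zero_vector is the store-side counterpart for the causally masked region: instead of writing four zeros one __half at a time, the whole 8-byte group is cleared with a single 64-bit store. This works because 0.0f is the all-zero bit pattern, and so is __half(0.0). A minimal sketch of what one such call does (illustrative, with a hypothetical name):

    #include <cuda_fp16.h>

    __device__ __inline__ void zero_four_halves(__half *dst) {
        // one 64-bit store of two float zeros is bit-identical to four
        // scalar __half zero stores (assumes dst is 8-byte aligned, as in
        // the kernel's 4-element groups)
        *((float2*) dst) = make_float2(0.0f, 0.0f);
    }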
@@ -73,7 +94,7 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Implicit time (diagonal masking)
 */
 template <typename input_t, typename output_t, typename acc_t, int log2_elements>
 __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     output_t *dst,
@@ -89,6 +110,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = 4;

     int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
     int local_seq = blockIdx.x + 1;
@@ -103,22 +125,33 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     // there might be multiple batches per warp. compute the index within the batch
     int local_idx = threadIdx.x;

-    src += first_batch * stride + local_idx;
-    dst += first_batch * stride + local_idx;
+    src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
+    dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;

     // load data from global memory
     acc_t elements[WARP_BATCH][WARP_ITERATIONS];
+    input_t temp_data[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         int batch_element_count = (i >= local_batches) ? 0 : local_seq;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < batch_element_count) {
-                elements[i][it] = (acc_t)src[i * element_count * stride + it * WARP_SIZE] * scale;
+                int itr_idx = i * element_count * stride + it * WARP_SIZE;
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
+
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    elements[i][it + element] = (acc_t)temp_data[element] * scale;
+                }
             } else {
-                elements[i][it] = -std::numeric_limits<acc_t>::infinity();
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                }
             }
         }
     }
@@ -140,26 +173,33 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         #pragma unroll
         for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
             if (it < warp_iteration_limit) {
                 elements[i][it] = std::exp((elements[i][it] - max_value[i]));
                 sum[i] += elements[i][it];
             }
         }
     }
     warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

     // store result
+    output_t out[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         if (i >= local_batches)
             break;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < local_seq) {
-                dst[i * element_count * stride + it * WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    out[element] = elements[i][it + element] / sum[i];
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
             } else if (element_index < element_count) {
-                dst[i * element_count * stride + it * WARP_SIZE] = 0;
+                copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
             } else {
                 break;
             }
         }
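The store loop in the upper-triangular kernel now has three vectorized cases per 4-element group: groups starting on or before the diagonal (element_index < local_seq) get normalized probabilities, groups past the diagonal but inside the row (element_index < element_count) get an explicit zero vector, and everything beyond ends the loop. A host-side sketch of the branch pattern (hypothetical helper, my illustration of the logic above):

    // For row r (0-based), local_seq = r + 1; columns are processed 4 at a time,
    // classified by the first element_index of each group.
    enum class StoreKind { Probs, Zeros, Done };

    StoreKind classify(int element_index, int local_seq, int element_count) {
        if (element_index < local_seq)     return StoreKind::Probs;  // on/left of the diagonal
        if (element_index < element_count) return StoreKind::Zeros;  // masked upper triangle
        return StoreKind::Done;                                      // past the row: break
    }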
@@ -183,6 +223,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = 4;

     int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
     int local_seq = blockIdx.x + 1;
@@ -197,37 +238,37 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     int local_idx = threadIdx.x;

     // the first element to process by the current thread
-    int thread_offset = first_batch * stride + local_idx;
+    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
     grad += thread_offset;
     output += thread_offset;
     gradInput += thread_offset;

     // load data from global memory
     acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] {0.0f};
-    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
+    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] {0.0f};
+    input_t temp_grad[ELEMENTS_PER_LDG_STG];
+    input_t temp_output[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         int batch_element_count = (i >= local_batches) ? 0 : local_seq;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < batch_element_count) {
-                output_reg[i][it] = output[i * element_count * stride + it * WARP_SIZE];
-            } else {
-                output_reg[i][it] = acc_t(0);
-            }
-        }
-        #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
-            if (element_index < batch_element_count) {
-                grad_reg[i][it] = (acc_t)grad[i * element_count * stride + it * WARP_SIZE] * output_reg[i][it];
-            } else {
-                grad_reg[i][it] = acc_t(0);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
+
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    output_reg[i][it + element] = (acc_t)temp_output[element];
+                }
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
+                }
             }
         }
     }

     acc_t sum[WARP_BATCH];
@@ -247,11 +288,16 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
         if (i >= local_batches)
             break;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < element_count) {
                 // compute gradients
-                gradInput[i * element_count * stride + it * WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
+                output_t out[ELEMENTS_PER_LDG_STG];
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
             }
         }
     }