Commit c6b77639 authored by hyunwoongko, committed by mshoeybi

Change PR per reviews

parent 0d350c8d
@@ -339,6 +339,7 @@ void dispatch_scaled_masked_softmax_forward(
     int attn_heads,
     int pad_batches)
 {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
     if (key_seq_len == 0) {
         return;
     } else {
@@ -357,6 +358,7 @@ void dispatch_scaled_masked_softmax_forward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
         dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
         dim3 threads(warp_size, warps_per_block, 1);
         // Launch code would be more elegant if C++ supported FOR CONSTEXPR
@@ -426,6 +428,7 @@ void dispatch_scaled_masked_softmax_backward(
     int batches,
     int attn_heads)
 {
+    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
     if (key_seq_len == 0) {
         return;
     } else {
...
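The asserts added above encode the hard limits of the fused dispatch: the kernel only covers softmax widths up to 2048 elements (hence the 0..2048 bound), and the forward grid is sized as query_seq_len / batches_per_block, which is only well formed when that division is exact. The following standalone sketch reproduces the same arithmetic for illustration only; the 128-thread block, 32-lane warp, and batches-per-warp rule are assumptions mirroring the usual dispatch constants, and the helper name is made up.

def softmax_warp_geometry(key_seq_len: int, threads_per_block: int = 128, hw_warp_size: int = 32) -> int:
    """Rough mirror of the dispatch arithmetic: batches handled per CUDA block (illustrative)."""
    assert 0 < key_seq_len <= 2048, "the dispatch only covers softmax widths up to 2048"
    log2_elements = (key_seq_len - 1).bit_length()    # log2_ceil(key_seq_len)
    next_power_of_two = 1 << log2_elements
    warp_size = min(next_power_of_two, hw_warp_size)  # narrow rows use a smaller logical warp
    batches_per_warp = 2 if next_power_of_two <= 128 else 1
    warps_per_block = threads_per_block // warp_size
    return warps_per_block * batches_per_warp         # batches_per_block

# The forward launch uses query_seq_len / batches_per_block blocks per head,
# so the new assert demands that the division be exact:
batches_per_block = softmax_warp_geometry(2048)       # -> 4 under these assumed constants
assert 1024 % batches_per_block == 0                  # e.g. query_seq_len = 1024 is accepted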
@@ -340,6 +340,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
     int softmax_elements_stride,
     int attn_batches)
 {
+    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
     if (softmax_elements == 0) {
         return;
     } else {
@@ -359,6 +360,8 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
         int blocks_per_seq = attn_batches / batches_per_block;
         dim3 blocks(seq_len, blocks_per_seq, 1);
         dim3 threads(warp_size, warps_per_block, 1);
@@ -428,6 +431,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
     int softmax_elements_stride,
     int attn_batches)
 {
+    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
     if (softmax_elements == 0) {
         return;
     } else {
@@ -447,6 +451,8 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
         int blocks_per_seq = attn_batches / batches_per_block;
         dim3 blocks(seq_len, blocks_per_seq, 1);
         dim3 threads(warp_size, warps_per_block, 1);
...
@@ -138,8 +138,8 @@ class FusedScaleMaskSoftmax(nn.Module):
             self.scaled_masked_softmax_fusion  # user want to fuse
             and self.input_in_float16  # input must be fp16
             and mask is not None  # mask tensor must not be None
-            and 16 < sq <= 2048  # sq must be 16 ~ 2048
-            and sk % 4 == 0  # sk must be divisor of 4
+            and 16 < sk <= 2048  # sq must be 16 ~ 2048
+            and sq % 4 == 0  # sk must be divisor of 4
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
             if 0 <= sk <= 2048:
...
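For reference, after this change the shape conditions above require the key/softmax dimension sk to satisfy 16 < sk <= 2048, while sq and attn_batches must be multiples of 4. Below is a minimal, self-contained sketch of just those shape checks, reusing the surrounding method's b/np/sq/sk naming; the helper function itself is hypothetical and not part of the commit.

def fused_kernel_shape_ok(sq: int, sk: int, b: int, np: int) -> bool:
    """Shape-only part of the fused-softmax eligibility test (illustrative)."""
    attn_batches = b * np
    return (
        16 < sk <= 2048            # softmax width must fit the fused kernel
        and sq % 4 == 0            # query length must be a multiple of 4
        and attn_batches % 4 == 0  # batch size * head count must be a multiple of 4
    )

# Example: 1024-token self-attention (sq == sk) with 16 heads and micro-batch 8.
assert fused_kernel_shape_ok(sq=1024, sk=1024, b=8, np=16)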