OpenDAS / apex · Commits · cf0b0f01

Commit cf0b0f01, authored Dec 09, 2021 by Hubert Lu
Fix some bugs related to THCState and cutlass
parent 9615983e

Showing 8 changed files with 66 additions and 94 deletions (+66, -94)
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu                    +7  -12
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu           +7  -12
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu   +6  -12
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                 +7  -12
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu                      +7  -12
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu             +7  -12
apex/contrib/csrc/multihead_attn/softmax.h                                        +19 -18
apex/contrib/csrc/multihead_attn/strided_batched_gemm.h                           +6  -4
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu

@@ -140,8 +140,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                             flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-  gemm_switch_fp32accum(state,
-                        a_layout_t,
+  gemm_switch_fp32accum(a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -194,8 +193,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   }
   // Matmul2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -371,8 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
                             flags));
   // MatMul2 Dgrad1
-  gemm_switch_fp32accum(state,
-                        a_layout_t,
+  gemm_switch_fp32accum(a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -394,8 +391,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
   // Matmul2 Dgrad2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -434,8 +430,7 @@ std::vector<torch::Tensor> bwd_cuda(
   assert(softmax_success);
   // Matmul1 Dgrad1
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -457,8 +452,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
   // Matmul1 Dgrad2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -595,3 +589,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace encdec
 } // end namespace multihead_attn
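Note: the same edit repeats in the remaining attention .cu files below; the leading THCState* argument is dropped from every gemm_switch_fp32accum call because THCState is gone from recent PyTorch. A minimal sketch of where the handle and stream come from instead, assuming current ATen/cuBLAS APIs (the wrapper name and body are illustrative, not the apex implementation):

// Hedged sketch, not apex code: once THCState is removed, the cuBLAS handle
// and stream that `state` used to carry can be fetched from ATen inside the
// GEMM wrapper, so the call sites above no longer pass it.
#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>

inline void gemm_switch_fp32accum_sketch(/* layouts, sizes, pointers ... */) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);  // run the GEMM on the current torch stream
  // ... dispatch to cublasGemmStridedBatchedEx / rocBLAS with fp32 accumulation ...
}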
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu

@@ -166,8 +166,7 @@ std::vector<torch::Tensor> fwd_cuda(
                             solution_index,
                             flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-  gemm_switch_fp32accum(state,
-                        a_layout_t,
+  gemm_switch_fp32accum(a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -220,8 +219,7 @@ std::vector<torch::Tensor> fwd_cuda(
   }
   // Matmul2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -435,8 +433,7 @@ std::vector<torch::Tensor> bwd_cuda(
                             flags));
   // MatMul2 Dgrad1
-  gemm_switch_fp32accum(state,
-                        a_layout_t,
+  gemm_switch_fp32accum(a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -458,8 +455,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
   // Matmul2 Dgrad2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -498,8 +494,7 @@ std::vector<torch::Tensor> bwd_cuda(
   assert(softmax_success);
   // Matmul1 Dgrad1
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -521,8 +516,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
   // Matmul1 Dgrad2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -675,3 +669,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace encdec_norm_add
 } // end namespace multihead_attn
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu

@@ -116,8 +116,7 @@ std::vector<torch::Tensor> fwd_cuda(
                             flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-  gemm_switch_fp32accum(state,
-                        a_layout_t,
+  gemm_switch_fp32accum(a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -162,8 +161,7 @@ std::vector<torch::Tensor> fwd_cuda(
   }
   // Matmul2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -327,8 +325,7 @@ std::vector<torch::Tensor> bwd_cuda(
   auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
   // MatMul2 Dgrad1
-  gemm_switch_fp32accum(state,
-                        a_layout_t,
+  gemm_switch_fp32accum(a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -350,8 +347,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
   // Matmul2 Dgrad2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -388,8 +384,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         stream);
   // Matmul1 Dgrad1
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -411,8 +406,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
   // Matmul1 Dgrad2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu

@@ -108,8 +108,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
                             flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-  gemm_switch_fp32accum(state,
-                        a_layout_t,
+  gemm_switch_fp32accum(a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -162,8 +161,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
   }
   // Matmul2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -327,8 +325,7 @@ std::vector<torch::Tensor> bwd_cuda(
   auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
   // MatMul2 Dgrad1
-  gemm_switch_fp32accum(state,
-                        a_layout_t,
+  gemm_switch_fp32accum(a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -350,8 +347,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
   // Matmul2 Dgrad2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -383,8 +379,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches * q_seq_len, stream);
   // Matmul1 Dgrad1
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -406,8 +401,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
   // Matmul1 Dgrad2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -489,3 +483,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu

@@ -106,8 +106,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                             flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-  gemm_switch_fp32accum(state,
-                        a_layout_t,
+  gemm_switch_fp32accum(a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -160,8 +159,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   }
   // Matmul2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -322,8 +320,7 @@ std::vector<torch::Tensor> bwd_cuda(
                             flags));
   // MatMul2 Dgrad1
-  gemm_switch_fp32accum(state,
-                        a_layout_t,
+  gemm_switch_fp32accum(a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -345,8 +342,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
   // Matmul2 Dgrad2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -385,8 +381,7 @@ std::vector<torch::Tensor> bwd_cuda(
   assert(softmax_success);
   // Matmul1 Dgrad1
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -408,8 +403,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
   // Matmul1 Dgrad2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -493,3 +487,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu

@@ -128,8 +128,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                             flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-  gemm_switch_fp32accum(state,
-                        a_layout_t,
+  gemm_switch_fp32accum(a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -182,8 +181,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   }
   // Matmul2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -380,8 +378,7 @@ std::vector<torch::Tensor> bwd_cuda(
                             flags));
   // MatMul2 Dgrad1
-  gemm_switch_fp32accum(state,
-                        a_layout_t,
+  gemm_switch_fp32accum(a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
@@ -403,8 +400,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
   // Matmul2 Dgrad2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -443,8 +439,7 @@ std::vector<torch::Tensor> bwd_cuda(
   assert(softmax_success);
   // Matmul1 Dgrad1
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
@@ -466,8 +461,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);
   // Matmul1 Dgrad2
-  gemm_switch_fp32accum(state,
-                        a_layout_n,
+  gemm_switch_fp32accum(a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
@@ -565,3 +559,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace self_norm_add
 } // end namespace multihead_attn
apex/contrib/csrc/multihead_attn/softmax.h

@@ -161,7 +161,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src,
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -186,7 +186,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src,
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
@@ -402,7 +402,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -426,7 +426,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
   auto seeds = at::cuda::philox::unpack(philox_args);
@@ -564,7 +564,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -588,7 +588,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
   curandStatePhilox4_32_10_t state;
@@ -874,7 +874,7 @@ __global__ void additive_masked_softmax_warp_forward(
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -899,7 +899,7 @@ __global__ void additive_masked_softmax_warp_forward(
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
@@ -1164,7 +1164,7 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src,
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -1189,7 +1189,7 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src,
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
@@ -1414,7 +1414,7 @@ __global__ void time_masked_softmax_warp_forward(
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -1439,7 +1439,7 @@ __global__ void time_masked_softmax_warp_forward(
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
@@ -1586,13 +1586,13 @@ int log2_ceil_native(int value) {
 }
 template <typename T>
 __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
 {
 #if CUDA_VERSION >= 9000 && !defined(__HIP_PLATFORM_HCC__)
   return __shfl_xor_sync(mask, value, laneMask, width);
 #else
   return __shfl_xor(value, laneMask, width);
 #endif
 }
 template <typename acc_t, int WARP_BATCH, int WARP_SIZE>
 __device__ __forceinline__ void warp_reduce_sum(acc_t *sum) {
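The WARP_SHFL_XOR_NATIVE helper shown in the @@ -1586 hunk already selects between __shfl_xor_sync and __shfl_xor at compile time; APEX_WARP_SHFL_XOR, used at the call sites above, presumably plays the same portability role with the mask passed first. A hedged sketch of such a wrapper (the real macro is defined elsewhere in apex and may differ):

// Hedged sketch of a portability wrapper matching the argument order of the
// APEX_WARP_SHFL_XOR(FULL_MASK, value, offset, WARP_SIZE) call sites above.
#if defined(__HIP_PLATFORM_HCC__) || (defined(CUDA_VERSION) && CUDA_VERSION < 9000)
// ROCm / pre-CUDA-9: no *_sync shuffle variants, so the mask is dropped.
#define APEX_WARP_SHFL_XOR_SKETCH(mask, value, lane_mask, width) \
  __shfl_xor((value), (lane_mask), (width))
#else
#define APEX_WARP_SHFL_XOR_SKETCH(mask, value, lane_mask, width) \
  __shfl_xor_sync((mask), (value), (lane_mask), (width))
#endif
// Usage inside a __device__ function mirrors the diff:
//   sum[i] += APEX_WARP_SHFL_XOR_SKETCH(FULL_MASK, sum[i], offset, WARP_SIZE);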
@@ -2149,7 +2149,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -2174,7 +2174,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
@@ -2754,7 +2754,7 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad,
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
@@ -2988,7 +2988,7 @@ masked_softmax_warp_backward(__half *gradInput, const __half *grad,
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
@@ -3137,3 +3137,4 @@ bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad,
   }
   return false;
 }
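The sum[i] += APEX_WARP_SHFL_XOR(...) loops patched throughout this header implement a standard XOR-butterfly warp reduction. A self-contained illustration of that pattern (kernel and names are illustrative, not taken from apex):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_sum_demo(const float *in, float *out) {
  const unsigned FULL_MASK = 0xffffffffu;
  float v = in[threadIdx.x];
  for (int offset = 32 / 2; offset > 0; offset /= 2)  // butterfly reduction
    v += __shfl_xor_sync(FULL_MASK, v, offset, 32);   // every lane ends with the sum
  if (threadIdx.x == 0) out[0] = v;
}

int main() {
  float h_in[32], h_out = 0.0f, *d_in, *d_out;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;        // expected sum: 32
  cudaMalloc(&d_in, 32 * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, 32 * sizeof(float), cudaMemcpyHostToDevice);
  warp_sum_demo<<<1, 32>>>(d_in, d_out);              // launch exactly one warp
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f\n", h_out);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}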
apex/contrib/csrc/multihead_attn/strided_batched_gemm.h

@@ -10,9 +10,9 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
-#include "cutlass/cutlass.h"
-#include "cutlass/gemm/gemm.h"
-#include "cutlass/gemm/wmma_gemm_traits.h"
+// #include "cutlass/cutlass.h"
+// #include "cutlass/gemm/gemm.h"
+// #include "cutlass/gemm/wmma_gemm_traits.h"
 // symbol to be automatically resolved by PyTorch libs
@@ -110,7 +110,8 @@ void HgemmStridedBatched(char transa, char transb, long m,
                          long n, long k, float alpha, const half *a, long lda,
                          long strideA, const half *b, long ldb, long strideB,
                          float beta, half *c, long ldc, long strideC,
-                         long batchCount) {
+                         half *d, long ldd, long strideD,
+                         long batchCount) {
   if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) ||
       (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX))
@@ -129,3 +130,4 @@ void HgemmStridedBatched(char transa, char transb, long m,
                     b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD,
                     batchCount);
 }
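For reference, a host-side sketch of the leading-dimension/stride bookkeeping the extended HgemmStridedBatched signature implies, including the new d/ldd/strideD output arguments and the INT_MAX guard shown above (all values illustrative):

#include <cassert>
#include <climits>

// Hedged sketch, not apex code: column-major strided-batched GEMM arguments.
void hgemm_strided_batched_args_example() {
  long m = 64, n = 128, k = 64, batchCount = 16;
  long lda = m, strideA = m * k;  // A: m x k per batch
  long ldb = k, strideB = k * n;  // B: k x n per batch
  long ldc = m, strideC = m * n;  // C: m x n per batch
  long ldd = m, strideD = m * n;  // D: same shape as C (new output argument)
  // Mirrors the guard in the diff: every dimension must fit a 32-bit BLAS API.
  assert(m < INT_MAX && n < INT_MAX && k < INT_MAX && lda < INT_MAX &&
         ldb < INT_MAX && ldc < INT_MAX && ldd < INT_MAX && batchCount < INT_MAX);
  (void)strideA; (void)strideB; (void)strideC; (void)strideD;
}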