Unverified Commit 4068f4b5 authored by Lu Fang's avatar Lu Fang Committed by GitHub
Browse files

[MISC] Replace c10::optional with std::optional (#11730)


Signed-off-by: default avatarLu Fang <lufang@fb.com>
parent 47831430
...@@ -928,7 +928,7 @@ void paged_attention_custom_launcher( ...@@ -928,7 +928,7 @@ void paged_attention_custom_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale, torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens, torch::Tensor& block_tables, torch::Tensor& context_lens,
int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes, int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
float k_scale, float v_scale) { float k_scale, float v_scale) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
...@@ -1086,7 +1086,7 @@ void paged_attention( ...@@ -1086,7 +1086,7 @@ void paged_attention(
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs] torch::Tensor& context_lens, // [num_seqs]
int64_t block_size, int64_t max_context_len, int64_t block_size, int64_t max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale) { const std::string& kv_cache_dtype, double k_scale, double v_scale) {
const int head_size = query.size(2); const int head_size = query.size(2);
if (kv_cache_dtype == "auto") { if (kv_cache_dtype == "auto") {
......
...@@ -9,6 +9,6 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, ...@@ -9,6 +9,6 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
double scale, torch::Tensor& block_tables, double scale, torch::Tensor& block_tables,
torch::Tensor& context_lens, int64_t block_size, torch::Tensor& context_lens, int64_t block_size,
int64_t max_context_len, int64_t max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, const std::string& kv_cache_dtype, double k_scale,
double v_scale); double v_scale);
...@@ -286,7 +286,7 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a, ...@@ -286,7 +286,7 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_meta, torch::Tensor const& bt_meta,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) { if (bias) {
......
...@@ -22,7 +22,7 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a, ...@@ -22,7 +22,7 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& e, torch::Tensor const& e,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
#endif #endif
void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
...@@ -30,7 +30,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a, ...@@ -30,7 +30,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& bt_meta, torch::Tensor const& bt_meta,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
// Checks for conformality // Checks for conformality
TORCH_CHECK(a.dim() == 2 && bt_nzs.dim() == 2 && c.dim() == 2); TORCH_CHECK(a.dim() == 2 && bt_nzs.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(1) == bt_nzs.size(0) && bt_nzs.size(1) * 2 == a.size(1) && TORCH_CHECK(c.size(1) == bt_nzs.size(0) && bt_nzs.size(1) * 2 == a.size(1) &&
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment