Unverified Commit 5c1e496a authored by Shiyan Deng's avatar Shiyan Deng Committed by GitHub
Browse files

[MISC] replace c10::optional with std::optional (#25602)


Signed-off-by: default avatarShiyan Deng <dsy842974287@meta.com>
parent e7f27ea6
...@@ -6,11 +6,11 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, ...@@ -6,11 +6,11 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
const int64_t rows_per_block); const int64_t rows_per_block);
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const c10::optional<at::Tensor>& in_bias, const std::optional<at::Tensor>& in_bias,
const int64_t CuCount); const int64_t CuCount);
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
const c10::optional<at::Tensor>& in_bias, at::Tensor& out_c, const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
const at::Tensor& scale_a, const at::Tensor& scale_b, const at::Tensor& scale_a, const at::Tensor& scale_b,
const int64_t CuCount); const int64_t CuCount);
......
...@@ -1271,7 +1271,7 @@ int mindiv(int N, int div1, int div2) { ...@@ -1271,7 +1271,7 @@ int mindiv(int N, int div1, int div2) {
} }
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const c10::optional<at::Tensor>& in_bias, const std::optional<at::Tensor>& in_bias,
const int64_t CuCount) { const int64_t CuCount) {
auto M_in = in_a.size(0); auto M_in = in_a.size(0);
auto K_in = in_a.size(1); auto K_in = in_a.size(1);
...@@ -1729,7 +1729,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, ...@@ -1729,7 +1729,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support #endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
const c10::optional<at::Tensor>& in_bias, at::Tensor& out_c, const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
const at::Tensor& scale_a, const at::Tensor& scale_b, const at::Tensor& scale_a, const at::Tensor& scale_b,
const int64_t CuCount) { const int64_t CuCount) {
static c10::ScalarType kFp8Type = is_fp8_ocp() static c10::ScalarType kFp8Type = is_fp8_ocp()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment