".github/vscode:/vscode.git/clone" did not exist on "b755e404a1544409e4dbab3db1a7d5eb11fa2566"
Unverified Commit b09ff7e9 authored by Przemyslaw Tredak, committed by GitHub

[pyTorch] Fix the compilation warnings (#2663)



* Fix the compilation warnings for the PyTorch extension

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Apply suggestion from @greptile-apps[bot]

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>

---------

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent 2894e493
@@ -313,6 +313,9 @@ struct GroupedTensor {
   SimpleTensor columnwise_amax;
   SimpleTensor scale;  // for FP8-DS only
+  NVTEScalingMode scaling_mode;
+  size_t num_tensors;
+
 
   // Shape information (OPTIONAL - empty if dimension is uniform across all tensors)
   // first_dims[i] = first dimension of tensor i (empty if all tensors have same first dim)
   // last_dims[i] = last dimension of tensor i (empty if all tensors have same last dim)
@@ -330,8 +333,6 @@ struct GroupedTensor {
 
   // Always 2D with positive dimensions
   NVTEShape logical_shape;
-  NVTEScalingMode scaling_mode;
-  size_t num_tensors;
   NVTEGroupedTensor nvte_tensor;
 
   GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
@@ -342,12 +343,12 @@ struct GroupedTensor {
         amax(),
         columnwise_amax(),
         scale(),
+        scaling_mode(scaling_mode),
         num_tensors(num_tensors),
         first_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
         last_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
         tensor_offsets(nullptr, std::vector<size_t>{0}, DType::kInt64),
         logical_shape(nvte_make_shape(nullptr, 1)),
-        scaling_mode(scaling_mode),
         nvte_tensor(0) {}
 
   explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }
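The GroupedTensor hunks above fix a -Wreorder warning: C++ initializes non-static data members in declaration order, ignoring the order written in the constructor's mem-initializer list, so a list that disagrees with the declarations draws the warning (and can read a not-yet-initialized member). A minimal sketch of the problem and the fix, with illustrative names rather than the real ones:

// reorder.cpp, compile with: g++ -Wall -c reorder.cpp
struct Grouped {
  int scaling_mode;  // declared first, so initialized first
  int num_tensors;   // declared second, so initialized second
  // The list below is written num_tensors(...), scaling_mode(...);
  // the compiler still initializes scaling_mode first and emits -Wreorder.
  Grouped(int mode, int n) : num_tensors(n), scaling_mode(mode) {}
};

// The fix, mirroring the commit: list initializers in declaration order.
struct GroupedFixed {
  int scaling_mode;
  int num_tensors;
  GroupedFixed(int mode, int n) : scaling_mode(mode), num_tensors(n) {}
};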
@@ -250,7 +250,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   fe::graph::SDPA_attributes sdpa_options;
   sdpa_options = fe::graph::SDPA_attributes()
                      .set_name("flash_attention")
-                     .set_is_inference(false)
                      .set_generate_stats(generate_stats)
                      .set_causal_mask(is_causal)
                      .set_causal_mask_bottom_right(is_bottom_right)
@@ -1810,7 +1810,7 @@ void fused_attn_fp8_fwd_impl_v1(
 
   fe::graph::SDPA_fp8_attributes sdpa_options;
   sdpa_options = fe::graph::SDPA_fp8_attributes()
                      .set_name("sdpa_fp8")
-                     .set_is_inference(false)
+                     .set_generate_stats(true)
                      .set_causal_mask(is_causal)
                      .set_attn_scale(attn_scale);
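Both attention hunks replace set_is_inference(false) with a set_generate_stats(...) call; the former setter is marked deprecated in recent cudnn-frontend releases in favor of the latter, so each call site produced a -Wdeprecated-declarations warning. A minimal sketch of that mechanism, with an invented builder standing in for the cudnn-frontend attribute classes:

// deprecated.cpp, compile with: g++ -Wall -c deprecated.cpp
struct SdpaOptions {
  bool generate_stats = false;
  [[deprecated("use set_generate_stats() instead")]]
  SdpaOptions &set_is_inference(bool infer) { generate_stats = !infer; return *this; }
  SdpaOptions &set_generate_stats(bool stats) { generate_stats = stats; return *this; }
};

void configure(SdpaOptions &opts) {
  // opts.set_is_inference(false);  // would emit -Wdeprecated-declarations
  opts.set_generate_stats(true);    // warning-free replacement
}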
@@ -548,6 +548,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
 
   ~CommOverlap() {}
 
+  using transformer_engine::CommOverlapCore::copy_into_buffer;
   void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
 
   at::Tensor get_buffer(bool local_chunk = false,
@@ -569,6 +570,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
 
   ~CommOverlapP2P() {}
 
+  using transformer_engine::CommOverlapP2PBase::copy_into_buffer;
   void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
 
   at::Tensor get_buffer(bool local_chunk = false,
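The two using-declarations address C++ name hiding: declaring copy_into_buffer(const at::Tensor &, bool) in the derived class hides every base-class overload with that name, which compilers report (for example via -Woverloaded-virtual when the hidden function is virtual). Re-exporting the base overloads keeps both callable and silences the warning. A minimal sketch with invented names:

// hiding.cpp, compile with: g++ -Wall -Woverloaded-virtual -c hiding.cpp
struct Base {
  virtual void copy_into_buffer(int buffer_id) {}
  virtual ~Base() = default;
};

struct Derived : Base {
  // Without this using-declaration the overload below hides
  // Base::copy_into_buffer(int) and -Woverloaded-virtual fires.
  using Base::copy_into_buffer;
  void copy_into_buffer(const float *data, bool local_chunk = false) {}
};

void demo(Derived &d) {
  d.copy_into_buffer(0);              // base overload, visible via 'using'
  d.copy_into_buffer(nullptr, true);  // derived overload
}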
@@ -492,8 +492,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0,
            py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
            py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
-      .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
-           py::arg("local_chunk") = false)
+      .def("copy_into_buffer",
+           static_cast<void (CommOverlap::*)(const at::Tensor &, bool)>(
+               &CommOverlap::copy_into_buffer),
+           py::arg("input"), py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
            py::arg("shape") = std::nullopt)
       .def("get_communication_stream", &CommOverlap::get_communication_stream);
@@ -510,8 +512,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1,
            py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
            py::arg("use_ce") = true, py::arg("aggregate") = false)
-      .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
-           py::arg("local_chunk") = false)
+      .def("copy_into_buffer",
+           static_cast<void (CommOverlapP2P::*)(const at::Tensor &, bool)>(
+               &CommOverlapP2P::copy_into_buffer),
+           py::arg("input"), py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
            py::arg("shape") = std::nullopt)
       .def("get_communication_stream", &CommOverlapP2P::get_communication_stream);
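With the using-declarations in place, copy_into_buffer names an overload set, so the bare &CommOverlap::copy_into_buffer previously passed to .def() no longer denotes a unique function; the static_cast to the exact member-function type selects the at::Tensor overload for pybind11 (py::overload_cast is an equivalent spelling). The disambiguation is plain C++, as this sketch with invented names shows:

// overload_ptr.cpp, compile with: g++ -std=c++17 -c overload_ptr.cpp
struct Overlap {
  void copy_into_buffer(int buffer_id) {}
  void copy_into_buffer(const float *data, bool local_chunk = false) {}
};

// Ambiguous, does not compile: which overload?
// auto f = &Overlap::copy_into_buffer;

// OK: the cast's target type picks one member of the overload set.
auto g = static_cast<void (Overlap::*)(const float *, bool)>(
    &Overlap::copy_into_buffer);

void demo(Overlap &o) { (o.*g)(nullptr, true); }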