Unverified Commit b09ff7e9 authored by Przemyslaw Tredak, committed by GitHub
Browse files

[pyTorch] Fix the compilation warnings (#2663)



* Fix the compilation warnings for the PyTorch extension
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Apply suggestion from @greptile-apps[bot]
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent 2894e493
......@@ -313,6 +313,9 @@ struct GroupedTensor {
SimpleTensor columnwise_amax;
SimpleTensor scale; // for FP8-DS only
NVTEScalingMode scaling_mode;
size_t num_tensors;
// Shape information (OPTIONAL - empty if dimension is uniform across all tensors)
// first_dims[i] = first dimension of tensor i (empty if all tensors have same first dim)
// last_dims[i] = last dimension of tensor i (empty if all tensors have same last dim)
......@@ -330,8 +333,6 @@ struct GroupedTensor {
// Always 2D with positive dimensions
NVTEShape logical_shape;
NVTEScalingMode scaling_mode;
size_t num_tensors;
NVTEGroupedTensor nvte_tensor;
GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
......@@ -342,12 +343,12 @@ struct GroupedTensor {
amax(),
columnwise_amax(),
scale(),
scaling_mode(scaling_mode),
num_tensors(num_tensors),
first_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
last_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
tensor_offsets(nullptr, std::vector<size_t>{0}, DType::kInt64),
logical_shape(nvte_make_shape(nullptr, 1)),
scaling_mode(scaling_mode),
nvte_tensor(0) {}
explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }
......
......@@ -250,7 +250,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
fe::graph::SDPA_attributes sdpa_options;
sdpa_options = fe::graph::SDPA_attributes()
.set_name("flash_attention")
.set_is_inference(false)
.set_generate_stats(generate_stats)
.set_causal_mask(is_causal)
.set_causal_mask_bottom_right(is_bottom_right)
......
......@@ -1810,7 +1810,7 @@ void fused_attn_fp8_fwd_impl_v1(
fe::graph::SDPA_fp8_attributes sdpa_options;
sdpa_options = fe::graph::SDPA_fp8_attributes()
.set_name("sdpa_fp8")
.set_is_inference(false)
.set_generate_stats(true)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
......
......@@ -548,6 +548,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
~CommOverlap() {}
using transformer_engine::CommOverlapCore::copy_into_buffer;
void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
at::Tensor get_buffer(bool local_chunk = false,
......@@ -569,6 +570,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
~CommOverlapP2P() {}
using transformer_engine::CommOverlapP2PBase::copy_into_buffer;
void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
at::Tensor get_buffer(bool local_chunk = false,
......
......@@ -492,8 +492,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0,
py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
.def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
py::arg("local_chunk") = false)
.def("copy_into_buffer",
static_cast<void (CommOverlap::*)(const at::Tensor &, bool)>(
&CommOverlap::copy_into_buffer),
py::arg("input"), py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
py::arg("shape") = std::nullopt)
.def("get_communication_stream", &CommOverlap::get_communication_stream);
......@@ -510,8 +512,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1,
py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
py::arg("use_ce") = true, py::arg("aggregate") = false)
.def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
py::arg("local_chunk") = false)
.def("copy_into_buffer",
static_cast<void (CommOverlapP2P::*)(const at::Tensor &, bool)>(
&CommOverlapP2P::copy_into_buffer),
py::arg("input"), py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
py::arg("shape") = std::nullopt)
.def("get_communication_stream", &CommOverlapP2P::get_communication_stream);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.