Unverified Commit dd083bdf authored by ldl's avatar ldl Committed by GitHub
Browse files

[PyTorch] Fix numeric overflow caused by int-type parameters and return value...


[PyTorch] Fix numeric overflow caused by int-type parameters and return value in the roundup function (#2034)
Signed-off-by: default avatarlvdunlin <lvdunlin@xiaomi.com>
Co-authored-by: default avatarlvdunlin <lvdunlin@xiaomi.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent bfab8c67
...@@ -286,7 +286,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape) { ...@@ -286,7 +286,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim); return std::vector<size_t>(shape.data, shape.data + shape.ndim);
} }
int roundup(const int value, const int multiple) { size_t roundup(const size_t value, const size_t multiple) {
assert(multiple > 0); assert(multiple > 0);
return ((value + multiple - 1) / multiple) * multiple; return ((value + multiple - 1) / multiple) * multiple;
} }
......
...@@ -417,7 +417,7 @@ void* getDataPtr(at::Tensor tensor, int offset = 0); ...@@ -417,7 +417,7 @@ void* getDataPtr(at::Tensor tensor, int offset = 0);
std::vector<size_t> convertShape(const NVTEShape& shape); std::vector<size_t> convertShape(const NVTEShape& shape);
int roundup(const int value, const int multiple); size_t roundup(const size_t value, const size_t multiple);
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment