Unverified Commit 68adf451 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Convert non-kernel cuda files to cpp (#1322)



* Fix file extensions
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix build
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* upgrade paddle container for CI
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent bfddb483
...@@ -76,7 +76,7 @@ jobs: ...@@ -76,7 +76,7 @@ jobs:
name: 'PaddlePaddle' name: 'PaddlePaddle'
runs-on: ubuntu-latest runs-on: ubuntu-latest
container: container:
image: nvcr.io/nvidia/paddlepaddle:24.07-py3 image: nvcr.io/nvidia/paddlepaddle:24.10-py3
options: --user root options: --user root
steps: steps:
- name: 'Checkout' - name: 'Checkout'
......
...@@ -25,7 +25,7 @@ def setup_paddle_extension( ...@@ -25,7 +25,7 @@ def setup_paddle_extension(
# Source files # Source files
csrc_source_files = Path(csrc_source_files) csrc_source_files = Path(csrc_source_files)
sources = [ sources = [
csrc_source_files / "extensions.cu", csrc_source_files / "extensions.cpp",
csrc_source_files / "common.cpp", csrc_source_files / "common.cpp",
csrc_source_files / "custom_ops.cu", csrc_source_files / "custom_ops.cu",
] ]
......
...@@ -26,7 +26,7 @@ def setup_pytorch_extension( ...@@ -26,7 +26,7 @@ def setup_pytorch_extension(
csrc_source_files = Path(csrc_source_files) csrc_source_files = Path(csrc_source_files)
extensions_dir = csrc_source_files / "extensions" extensions_dir = csrc_source_files / "extensions"
sources = [ sources = [
csrc_source_files / "common.cu", csrc_source_files / "common.cpp",
csrc_source_files / "ts_fp8_op.cpp", csrc_source_files / "ts_fp8_op.cpp",
] + all_files_in_dir(extensions_dir) ] + all_files_in_dir(extensions_dir)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
************************************************************************/ ************************************************************************/
#include "common.h" #include "common.h"
#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transformer_engine.h"
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment