Unverified Commit e0da01ea authored by xcnick's avatar xcnick Committed by GitHub
Browse files

[hotfix] fix build error when torch version >= 1.13 (#1803)

parent f5a92c28
......@@ -2,8 +2,13 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/torch.h>
#if TORCH_VERSION_MINOR >= 13
#include <torch/csrc/distributed/c10d/Types.hpp>
#else
#include <c10d/Types.hpp>
#endif
#include <iostream>
#include "context.h"
......
......@@ -4,8 +4,14 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <torch/torch.h>
#if TORCH_VERSION_MINOR >= 13
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#else
#include <c10d/ProcessGroup.hpp>
#endif
#include <string>
#include <type_traits>
......@@ -157,4 +163,4 @@ class MultiHeadAttention {
c10::intrusive_ptr<c10d::ProcessGroup> pg;
int pg_size;
};
\ No newline at end of file
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment