Unverified Commit 9765b5c4 authored by Hongxia Yang's avatar Hongxia Yang Committed by GitHub
Browse files

[ROCm][Bugfix] Fixed several bugs related to rccl path and attention selector logic (#3699)

parent 430530fc
...@@ -90,6 +90,6 @@ RUN cd /app \ ...@@ -90,6 +90,6 @@ RUN cd /app \
&& cd .. && cd ..
RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all] RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3
CMD ["/bin/bash"] CMD ["/bin/bash"]
...@@ -5,7 +5,7 @@ starlette ...@@ -5,7 +5,7 @@ starlette
requests requests
py-cpuinfo py-cpuinfo
psutil psutil
ray >= 2.9 ray == 2.9.3
sentencepiece # Required for LLaMA tokenizer. sentencepiece # Required for LLaMA tokenizer.
numpy numpy
tokenizers>=0.15.0 tokenizers>=0.15.0
......
...@@ -405,8 +405,8 @@ def _check_use_naive_attention() -> bool: ...@@ -405,8 +405,8 @@ def _check_use_naive_attention() -> bool:
if not is_hip(): if not is_hip():
return False return False
# For ROCm, check whether flash attention is installed or not. # For ROCm, check whether flash attention is installed or not.
has_flash_attn = importlib.util.find_spec("flash_attn") is None use_naive_attention = importlib.util.find_spec("flash_attn") is None
if not has_flash_attn: if use_naive_attention:
logger.warning("flash_attn is not installed. Using naive attention. " logger.warning("flash_attn is not installed. Using naive attention. "
"This will take significantly more GPU memory.") "This will take significantly more GPU memory.")
return True return True
......
...@@ -41,7 +41,7 @@ else: ...@@ -41,7 +41,7 @@ else:
if torch.version.cuda is not None: if torch.version.cuda is not None:
so_file = "libnccl.so.2" so_file = "libnccl.so.2"
elif torch.version.hip is not None: elif torch.version.hip is not None:
so_file = "librccl.so.2" so_file = "librccl.so.1"
else: else:
raise ValueError("NCCL only supports CUDA and ROCm backends.") raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.debug(f"Loading nccl from library {so_file}") logger.debug(f"Loading nccl from library {so_file}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment