Unverified Commit a8f12a63 authored by Richard Liu's avatar Richard Liu Committed by GitHub
Browse files

Fix env vars for running Ray distributed backend on GKE (#15166)


Signed-off-by: default avatarRichard Liu <ricliu@google.com>
parent 69ae2380
......@@ -340,6 +340,8 @@ class RayDistributedExecutor(DistributedExecutorBase):
and v not in self.non_carry_over_env_vars
]
env_vars_to_copy.extend(current_platform.additional_env_vars)
# Copy existing env vars to each worker's args
for args in all_args_to_update_environment_variables:
# TODO: refactor platform-specific env vars
......
......@@ -112,6 +112,8 @@ class Platform:
supported_quantization: list[str] = []
additional_env_vars: list[str] = []
def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA
......
......@@ -29,6 +29,10 @@ class TpuPlatform(Platform):
"tpu_int8", "compressed-tensors", "compressed_tensors"
]
additional_env_vars: list[str] = [
"TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
]
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment