Commit 145787ae authored by zhuwenwen's avatar zhuwenwen
Browse files

fix merge

parent 408f0a79
......@@ -33,11 +33,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
<<<<<<< HEAD
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] if not is_hip() else ["auto"]
=======
KV_CACHE_DTYPE = ["auto", "fp8"]
>>>>>>> v0.4.1
KV_CACHE_DTYPE = ["auto", "fp8"] if not is_hip() else ["auto"]
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
......
......@@ -25,11 +25,7 @@ SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
<<<<<<< HEAD
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] if not is_hip() else ["auto"]
=======
KV_CACHE_DTYPE = ["auto", "fp8"]
>>>>>>> v0.4.1
KV_CACHE_DTYPE = ["auto", "fp8"] if not is_hip() else ["auto"]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
......
......@@ -346,13 +346,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
<<<<<<< HEAD
super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, linear_method,
=======
super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, quant_config,
>>>>>>> v0.4.2
self.output_sizes)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
......@@ -514,12 +509,8 @@ class QKVParallelLinear(ColumnParallelLinear):
]
super().__init__(input_size, output_size, bias, False, skip_bias_add,
<<<<<<< HEAD
params_dtype, linear_method, output_sizes)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
=======
params_dtype, quant_config, output_sizes)
>>>>>>> v0.4.2
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self,
param: Parameter,
......
......@@ -54,18 +54,12 @@ def _get_quantization_config(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
<<<<<<< HEAD
linear_method = quant_config.get_linear_method()
return quant_config
if linear_method != None:
if quant_config != None:
os.environ['LLAMA_NN'] = '0'
return linear_method
=======
return quant_config
return None
>>>>>>> v0.4.2
def _get_model_initialization_kwargs(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment