Unverified Commit f2d9d397 authored by gongchensu, committed by GitHub

Merge pull request #172 from gongchensu/Issue/170

Issue/170 - Add HYGON support and improve device type handling.
parents d21a4f59 ed33c3a9
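
With this change the Hygon backend can be selected the same way as the other devices; for example, reusing the model path already shown in the script's usage message:

    python examples/jiuge.py --hygon --model_path=~/TinyLlama-1.1B-Chat-v1.0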
@@ -47,6 +47,11 @@ def get_args():
         action="store_true",
         help="Run cambricon test",
     )
+    parser.add_argument(
+        "--hygon",
+        action="store_true",
+        help="Run hygon test",
+    )
     parser.add_argument(
         "--model_path",
         type=str,
@@ -245,9 +250,11 @@ if __name__ == "__main__":
         device_str = "cuda"
     elif args.cambricon:
         device_str = "mlu"
+    elif args.hygon:
+        device_str = "cuda"
     else:
         print(
-            "Usage: python examples/jiuge.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_dir>\n"
+            "Usage: python examples/jiuge.py [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon | --hygon] --model_path=<path/to/model_dir>\n"
             "such as, python examples/jiuge.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0"
         )
         sys.exit(1)
@@ -62,16 +62,18 @@ class InfiniLMBenchmark(BaseBenchmark):
         self.benchmark = benchmark
         # Map device type string to infinicore device
+        # Note: These map to the Python device type strings used by infinicore.device()
+        # which correspond to _TORCH_DEVICE_MAP values in InfiniCore/python/infinicore/device.py
         device_map = {
             "cpu": "cpu",
             "nvidia": "cuda",
             "cambricon": "mlu",
-            "ascend": "ascend",
-            "metax": "metax",
-            "moore": "moore",
-            "iluvatar": "iluvatar",
-            "kunlun": "kunlun",
-            "hygon": "hygon",
+            "ascend": "npu",
+            "metax": "cuda",
+            "moore": "musa",
+            "iluvatar": "cuda",
+            "kunlun": "cuda",
+            "hygon": "cuda",
         }
         device_name = device_map.get(device_type_str.lower(), "cpu")
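
For reference, a minimal standalone sketch of the lookup fallback after this change; the dict values are the ones from the hunk above, and resolve_device is only a hypothetical wrapper name used for illustration:

    device_map = {
        "cpu": "cpu",
        "nvidia": "cuda",
        "cambricon": "mlu",
        "ascend": "npu",
        "metax": "cuda",
        "moore": "musa",
        "iluvatar": "cuda",
        "kunlun": "cuda",
        "hygon": "cuda",
    }

    def resolve_device(device_type_str):
        # Unrecognized vendor strings fall back to "cpu" instead of raising.
        return device_map.get(device_type_str.lower(), "cpu")

    assert resolve_device("HYGON") == "cuda"
    assert resolve_device("unknown-vendor") == "cpu"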
@@ -180,6 +182,13 @@ class InfiniLMBenchmark(BaseBenchmark):
         start_time = time.perf_counter()
+        # For cpp backend, reset cache before generation if use_cache is enabled
+        if self.model.use_cache and hasattr(self.model, "_model") and hasattr(self.model._model, "reset_cache"):
+            batch_size = input_ids.shape[0]
+            seq_len = input_ids.shape[1]
+            max_cache_len = max_steps + seq_len
+            self.model.reset_cache(batch_size=batch_size, initial_capacity=max_cache_len)
         # Use model's built-in generate() method which properly handles KV cache
         # Pass sampling parameters (temperature, topk, topp) via kwargs
         output_ids = self.model.generate(
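
A quick arithmetic sketch of the cache sizing used above; the numbers are illustrative only, and plain integers stand in for input_ids.shape:

    batch_size, seq_len = 4, 32          # stands in for input_ids.shape
    max_steps = 128                      # upper bound on tokens generate() may append
    max_cache_len = max_steps + seq_len  # 160 slots per sequence: 32 prompt tokens plus up to 128 new ones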