Unverified commit cf632499, authored by Yanan Cao and committed by GitHub
Browse files

[Kernel] [Helion] [15/N] Split config files into per-platform files (#36698)


Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
parent a3774a81
......@@ -160,10 +160,11 @@ class TestConfigManager:
"""Test getting config file path for a kernel."""
manager = ConfigManager(base_dir="/tmp")
file_path = manager.get_config_file_path("silu_mul_fp8")
dir_path = manager.get_config_file_path("silu_mul_fp8")
assert dir_path == Path("/tmp/silu_mul_fp8")
expected_path = Path("/tmp/silu_mul_fp8.json")
assert file_path == expected_path
file_path = manager.get_config_file_path("silu_mul_fp8", "nvidia_h100")
assert file_path == Path("/tmp/silu_mul_fp8/nvidia_h100.json")
def test_ensure_base_dir_exists(self):
"""Test ensuring base directory exists."""
......@@ -189,19 +190,19 @@ class TestConfigManager:
assert config_set.get_platforms() == []
def test_load_config_set_valid_file(self):
"""Test loading config set from valid file."""
"""Test loading config set from per-platform files."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
kernel_config = {
"block_sizes": [128, 64],
"num_warps": 8,
"num_stages": 6,
"pid_type": "persistent_interleaved",
}
config_data = {"h100": {"batch_32_hidden_4096": kernel_config}}
config_file = Path(temp_dir) / "test_kernel.json"
with open(config_file, "w") as f:
json.dump(config_data, f)
kernel_dir = Path(temp_dir) / "test_kernel"
kernel_dir.mkdir()
platform_file = kernel_dir / "h100.json"
with open(platform_file, "w") as f:
json.dump({"batch_32_hidden_4096": kernel_config}, f)
manager = ConfigManager(base_dir=temp_dir)
config_set = manager.load_config_set("test_kernel")
......@@ -210,7 +211,6 @@ class TestConfigManager:
assert config_set.kernel_name == "test_kernel"
assert config_set.get_platforms() == ["h100"]
# Verify the config was loaded correctly
config = config_set.get_config("h100", "batch_32_hidden_4096")
assert isinstance(config, helion.Config)
assert config.block_sizes == [128, 64]
......@@ -219,7 +219,9 @@ class TestConfigManager:
def test_load_config_set_invalid_json(self):
"""Test loading config set from file with invalid JSON."""
with tempfile.TemporaryDirectory() as temp_dir:
config_file = Path(temp_dir) / "test_kernel.json"
kernel_dir = Path(temp_dir) / "test_kernel"
kernel_dir.mkdir()
config_file = kernel_dir / "h100.json"
with open(config_file, "w") as f:
f.write("invalid json content {")
......@@ -231,9 +233,8 @@ class TestConfigManager:
assert config_set.get_platforms() == []
def test_save_config_set(self):
"""Test saving ConfigSet to file."""
"""Test saving ConfigSet to per-platform files."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
kernel_config = {
"block_sizes": [256, 128],
"num_warps": 16,
......@@ -246,31 +247,34 @@ class TestConfigManager:
manager = ConfigManager(base_dir=temp_dir)
saved_path = manager.save_config_set(config_set)
expected_path = Path(temp_dir) / "test_kernel.json"
assert saved_path == expected_path
assert saved_path.exists()
expected_dir = Path(temp_dir) / "test_kernel"
assert saved_path == expected_dir
assert saved_path.is_dir()
with open(saved_path) as f:
platform_file = expected_dir / "h100.json"
assert platform_file.exists()
with open(platform_file) as f:
loaded_data = json.load(f)
assert loaded_data == data
assert loaded_data == data["h100"]
def test_save_config_set_creates_directory(self):
"""Test that save_config_set creates parent directories if needed."""
with tempfile.TemporaryDirectory() as temp_dir:
# base_dir sits two levels below the fresh temp dir, so it does not exist yet.
nested_dir = Path(temp_dir) / "nested" / "configs"
# NOTE(review): config_set is bound twice in this listing (an empty ConfigSet,
# then one built from `data`); presumably the first binding is the pre-change
# side of the diff -- confirm against the real file.
config_set = ConfigSet("test_kernel")
data = {"h100": {"default": {"num_warps": 4}}}
config_set = ConfigSet.from_dict("test_kernel", data)
manager = ConfigManager(base_dir=nested_dir)
saved_path = manager.save_config_set(config_set)
# Saving must have created the missing base directory.
assert nested_dir.exists()
assert nested_dir.is_dir()
assert saved_path.exists()
# The returned path is checked to be a directory that contains one JSON file
# named after the platform key in `data`.
assert saved_path.is_dir()
assert (saved_path / "h100.json").exists()
def test_get_platform_configs(self):
"""Test getting all configs for a specific platform."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
config_1 = {"num_warps": 4, "num_stages": 3, "block_sizes": [64, 32]}
config_2 = {"num_warps": 8, "num_stages": 5, "block_sizes": [128, 64]}
default_config = {
......@@ -280,17 +284,19 @@ class TestConfigManager:
}
config_3 = {"num_warps": 2, "num_stages": 2, "block_sizes": [32, 16]}
config_data = {
"h100": {
"batch_32_hidden_4096": config_1,
"batch_64_hidden_2048": config_2,
"default": default_config,
},
"a100": {"batch_16_hidden_1024": config_3},
}
config_file = Path(temp_dir) / "test_kernel.json"
with open(config_file, "w") as f:
json.dump(config_data, f)
kernel_dir = Path(temp_dir) / "test_kernel"
kernel_dir.mkdir()
with open(kernel_dir / "h100.json", "w") as f:
json.dump(
{
"batch_32_hidden_4096": config_1,
"batch_64_hidden_2048": config_2,
"default": default_config,
},
f,
)
with open(kernel_dir / "a100.json", "w") as f:
json.dump({"batch_16_hidden_1024": config_3}, f)
manager = ConfigManager(base_dir=temp_dir)
......@@ -302,7 +308,6 @@ class TestConfigManager:
for config in h100_configs.values():
assert isinstance(config, helion.Config)
# Verify specific config details
assert h100_configs["batch_32_hidden_4096"].num_warps == 4
assert h100_configs["default"].num_stages == 7
......
......@@ -8,23 +8,15 @@ operations, including naming conventions, directory resolution, and file I/O.
Config File Structure
---------------------
Each kernel has a single JSON config file: {kernel_name}.json
The file uses a simplified 2-layer hierarchical structure:
{
"h100": { # GPU platform
"default": { ... }, # Fallback configuration
"batch_32_hidden_4096": { ... },
"batch_64_hidden_8192": { ... }
},
"a100": {
"default": { ... },
"batch_16_hidden_2048": { ... }
}
}
Example file: silu_mul_fp8.json
Each kernel has a directory: {kernel_name}/
Inside, each GPU platform has its own JSON file: {kernel_name}/{platform}.json
For example:
silu_mul_fp8/
nvidia_h100.json # { "default": {...}, "batch_32_hidden_4096": {...} }
nvidia_h200.json # { "batch_16_hidden_2048": {...} }
Each platform file maps config keys to Helion config objects.
Config keys should be structured strings that encode the relevant
parameters (e.g., "batch_32_hidden_4096", "seq_512_heads_16", "fp8_batch_64").
......@@ -212,8 +204,15 @@ class ConfigManager:
cls._instance = None
cls._instance_base_dir = None
# Single-file layout: resolves a kernel to "<base_dir>/<kernel_name>.json".
# NOTE(review): a second get_config_file_path that takes an optional platform
# appears a few lines further down in this listing; presumably this two-line
# form is the pre-change side of the diff -- confirm before relying on it.
def get_config_file_path(self, kernel_name: str) -> Path:
return self._base_dir / f"{kernel_name}.json"
# Return "<base_dir>/<kernel_name>". The path is only computed here; nothing is
# created or checked on disk.
def get_kernel_dir(self, kernel_name: str) -> Path:
return self._base_dir / kernel_name
# Resolve a config path for a kernel.
#   platform given -> "<kernel dir>/<platform>.json" (that platform's file)
#   platform None  -> the kernel directory itself (a directory, despite the
#                     "file_path" in the method name)
def get_config_file_path(
self, kernel_name: str, platform: str | None = None
) -> Path:
if platform is not None:
return self.get_kernel_dir(kernel_name) / f"{platform}.json"
# No platform requested: fall back to the directory returned by get_kernel_dir.
return self.get_kernel_dir(kernel_name)
def ensure_base_dir_exists(self) -> Path:
self._base_dir.mkdir(parents=True, exist_ok=True)
......@@ -230,39 +229,59 @@ class ConfigManager:
f"Config directory '{self._base_dir}' is not writable: {e}"
) from e
# NOTE(review): two definitions are interleaved in this span of the listing.
# The load_config_set header, the one-argument get_config_file_path call and the
# two ConfigSet.from_dict returns look like the pre-change single-file loader;
# the remaining lines form _load_platform_file. Confirm against the real file.
def load_config_set(self, kernel_name: str) -> ConfigSet:
config_path = self.get_config_file_path(kernel_name)
# _load_platform_file: read "<kernel dir>/<platform>.json" and return the value
# parsed by json.load. A missing file, invalid JSON or an OSError all yield {}.
def _load_platform_file(self, kernel_name: str, platform: str) -> dict[str, Any]:
config_path = self.get_config_file_path(kernel_name, platform)
# A file that does not exist is not treated as an error: empty result.
if not config_path.exists():
return ConfigSet.from_dict(kernel_name, {})
return {}
try:
with open(config_path) as f:
data = json.load(f)
return ConfigSet.from_dict(kernel_name, data)
return json.load(f)
except (json.JSONDecodeError, OSError) as e:
# Best effort: the failure is logged and an empty result is returned rather
# than propagating the exception to the caller.
logger.error("Failed to load config file %s: %s", config_path, e)
return {}
# Build a ConfigSet for `kernel_name` from every "*.json" file found directly
# in its kernel directory. When that directory does not exist, the result is
# ConfigSet.from_dict(kernel_name, {}).
def load_config_set(self, kernel_name: str) -> ConfigSet:
kernel_dir = self.get_kernel_dir(kernel_name)
if not kernel_dir.is_dir():
return ConfigSet.from_dict(kernel_name, {})
# Maps platform name -> parsed content of that platform's file.
data: dict[str, Any] = {}
# sorted() fixes the processing order independently of how glob yields entries.
for platform_file in sorted(kernel_dir.glob("*.json")):
# The file name without its ".json" suffix is used as the platform key.
platform = platform_file.stem
try:
with open(platform_file) as f:
platform_data = json.load(f)
data[platform] = platform_data
except (json.JSONDecodeError, OSError) as e:
# A file that cannot be read or parsed is logged and left out of `data`;
# the loop continues with the remaining files.
logger.error("Failed to load config file %s: %s", platform_file, e)
return ConfigSet.from_dict(kernel_name, data)
# Return {config_key: config} for one platform of a kernel, or {} when the
# platform file yields no data.
# NOTE(review): config_set is bound twice in this listing -- first from
# load_config_set, then from a ConfigSet built out of the single platform's
# data. Presumably the first binding is the pre-change side of the diff; confirm.
def get_platform_configs(
self, kernel_name: str, platform: str
) -> dict[str, helion.Config]:
config_set = self.load_config_set(kernel_name)
platform_data = self._load_platform_file(kernel_name, platform)
# Empty/falsy platform data short-circuits before any ConfigSet is built.
if not platform_data:
return {}
# Wrap the single platform's mapping so the ConfigSet accessors can be reused.
config_set = ConfigSet.from_dict(kernel_name, {platform: platform_data})
config_keys = config_set.get_config_keys(platform)
return {
config_key: config_set.get_config(platform, config_key)
for config_key in config_keys
}
# Write a ConfigSet to disk and return the path that was written.
# NOTE(review): two versions of the body are interleaved in this listing. One
# dumps the whole to_dict() into a single file at config_path and returns
# config_path. The other creates the kernel directory, writes one
# "<platform>.json" per top-level key of to_dict() and returns kernel_dir.
# Presumably the config_path lines are the pre-change side of the diff; confirm.
def save_config_set(self, config_set: ConfigSet) -> Path:
config_path = self.get_config_file_path(config_set.kernel_name)
config_path.parent.mkdir(parents=True, exist_ok=True)
kernel_dir = self.get_kernel_dir(config_set.kernel_name)
# parents=True / exist_ok=True: missing ancestors are created and an existing
# directory is not an error.
kernel_dir.mkdir(parents=True, exist_ok=True)
with open(config_path, "w") as f:
json.dump(config_set.to_dict(), f, indent=2)
# The top-level keys of to_dict() are iterated as platform names; each
# platform's mapping is dumped to its own file inside kernel_dir.
full_data = config_set.to_dict()
for platform, platform_data in full_data.items():
platform_path = kernel_dir / f"{platform}.json"
with open(platform_path, "w") as f:
json.dump(platform_data, f, indent=2)
logger.info("Saved config to: %s", platform_path)
logger.info("Saved config to: %s", config_path)
return config_path
return kernel_dir
def save_configs(
self,
......@@ -271,11 +290,18 @@ class ConfigManager:
configs: dict[str, "helion.Config"],
) -> Path:
"""Save configs for a kernel/platform, merging with existing."""
config_set = self.load_config_set(kernel_name)
platform_data = self._load_platform_file(kernel_name, platform)
for config_key, config in configs.items():
config_set.set_config(platform, config_key, config)
return self.save_config_set(config_set)
platform_data[config_key] = json.loads(config.to_json())
platform_path = self.get_config_file_path(kernel_name, platform)
platform_path.parent.mkdir(parents=True, exist_ok=True)
with open(platform_path, "w") as f:
json.dump(platform_data, f, indent=2)
logger.info("Saved config to: %s", platform_path)
return platform_path
# Report whether `config_key` is present for the given kernel and platform.
# NOTE(review): two bodies are interleaved in this listing -- one goes through
# load_config_set + has_config, the other checks membership in the data returned
# by _load_platform_file. Presumably the first pair is the pre-change side of
# the diff; confirm against the real file.
def config_exists(self, kernel_name: str, platform: str, config_key: str) -> bool:
config_set = self.load_config_set(kernel_name)
return config_set.has_config(platform, config_key)
platform_data = self._load_platform_file(kernel_name, platform)
return config_key in platform_data
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment