Unverified Commit cf632499 authored by Yanan Cao's avatar Yanan Cao Committed by GitHub
Browse files

[Kernel] [Helion] [15/N] Split config files into per-platform files (#36698)


Signed-off-by: default avatarYanan Cao <gmagogsfm@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent a3774a81
......@@ -160,10 +160,11 @@ class TestConfigManager:
"""Test getting config file path for a kernel."""
manager = ConfigManager(base_dir="/tmp")
file_path = manager.get_config_file_path("silu_mul_fp8")
dir_path = manager.get_config_file_path("silu_mul_fp8")
assert dir_path == Path("/tmp/silu_mul_fp8")
expected_path = Path("/tmp/silu_mul_fp8.json")
assert file_path == expected_path
file_path = manager.get_config_file_path("silu_mul_fp8", "nvidia_h100")
assert file_path == Path("/tmp/silu_mul_fp8/nvidia_h100.json")
def test_ensure_base_dir_exists(self):
"""Test ensuring base directory exists."""
......@@ -189,19 +190,19 @@ class TestConfigManager:
assert config_set.get_platforms() == []
def test_load_config_set_valid_file(self):
"""Test loading config set from valid file."""
"""Test loading config set from per-platform files."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
kernel_config = {
"block_sizes": [128, 64],
"num_warps": 8,
"num_stages": 6,
"pid_type": "persistent_interleaved",
}
config_data = {"h100": {"batch_32_hidden_4096": kernel_config}}
config_file = Path(temp_dir) / "test_kernel.json"
with open(config_file, "w") as f:
json.dump(config_data, f)
kernel_dir = Path(temp_dir) / "test_kernel"
kernel_dir.mkdir()
platform_file = kernel_dir / "h100.json"
with open(platform_file, "w") as f:
json.dump({"batch_32_hidden_4096": kernel_config}, f)
manager = ConfigManager(base_dir=temp_dir)
config_set = manager.load_config_set("test_kernel")
......@@ -210,7 +211,6 @@ class TestConfigManager:
assert config_set.kernel_name == "test_kernel"
assert config_set.get_platforms() == ["h100"]
# Verify the config was loaded correctly
config = config_set.get_config("h100", "batch_32_hidden_4096")
assert isinstance(config, helion.Config)
assert config.block_sizes == [128, 64]
......@@ -219,7 +219,9 @@ class TestConfigManager:
def test_load_config_set_invalid_json(self):
"""Test loading config set from file with invalid JSON."""
with tempfile.TemporaryDirectory() as temp_dir:
config_file = Path(temp_dir) / "test_kernel.json"
kernel_dir = Path(temp_dir) / "test_kernel"
kernel_dir.mkdir()
config_file = kernel_dir / "h100.json"
with open(config_file, "w") as f:
f.write("invalid json content {")
......@@ -231,9 +233,8 @@ class TestConfigManager:
assert config_set.get_platforms() == []
def test_save_config_set(self):
"""Test saving ConfigSet to file."""
"""Test saving ConfigSet to per-platform files."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
kernel_config = {
"block_sizes": [256, 128],
"num_warps": 16,
......@@ -246,31 +247,34 @@ class TestConfigManager:
manager = ConfigManager(base_dir=temp_dir)
saved_path = manager.save_config_set(config_set)
expected_path = Path(temp_dir) / "test_kernel.json"
assert saved_path == expected_path
assert saved_path.exists()
expected_dir = Path(temp_dir) / "test_kernel"
assert saved_path == expected_dir
assert saved_path.is_dir()
with open(saved_path) as f:
platform_file = expected_dir / "h100.json"
assert platform_file.exists()
with open(platform_file) as f:
loaded_data = json.load(f)
assert loaded_data == data
assert loaded_data == data["h100"]
def test_save_config_set_creates_directory(self):
"""Test that save_config_set creates parent directories if needed."""
with tempfile.TemporaryDirectory() as temp_dir:
nested_dir = Path(temp_dir) / "nested" / "configs"
config_set = ConfigSet("test_kernel")
data = {"h100": {"default": {"num_warps": 4}}}
config_set = ConfigSet.from_dict("test_kernel", data)
manager = ConfigManager(base_dir=nested_dir)
saved_path = manager.save_config_set(config_set)
assert nested_dir.exists()
assert nested_dir.is_dir()
assert saved_path.exists()
assert saved_path.is_dir()
assert (saved_path / "h100.json").exists()
def test_get_platform_configs(self):
"""Test getting all configs for a specific platform."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
config_1 = {"num_warps": 4, "num_stages": 3, "block_sizes": [64, 32]}
config_2 = {"num_warps": 8, "num_stages": 5, "block_sizes": [128, 64]}
default_config = {
......@@ -280,17 +284,19 @@ class TestConfigManager:
}
config_3 = {"num_warps": 2, "num_stages": 2, "block_sizes": [32, 16]}
config_data = {
"h100": {
"batch_32_hidden_4096": config_1,
"batch_64_hidden_2048": config_2,
"default": default_config,
},
"a100": {"batch_16_hidden_1024": config_3},
}
config_file = Path(temp_dir) / "test_kernel.json"
with open(config_file, "w") as f:
json.dump(config_data, f)
kernel_dir = Path(temp_dir) / "test_kernel"
kernel_dir.mkdir()
with open(kernel_dir / "h100.json", "w") as f:
json.dump(
{
"batch_32_hidden_4096": config_1,
"batch_64_hidden_2048": config_2,
"default": default_config,
},
f,
)
with open(kernel_dir / "a100.json", "w") as f:
json.dump({"batch_16_hidden_1024": config_3}, f)
manager = ConfigManager(base_dir=temp_dir)
......@@ -302,7 +308,6 @@ class TestConfigManager:
for config in h100_configs.values():
assert isinstance(config, helion.Config)
# Verify specific config details
assert h100_configs["batch_32_hidden_4096"].num_warps == 4
assert h100_configs["default"].num_stages == 7
......
......@@ -8,23 +8,15 @@ operations, including naming conventions, directory resolution, and file I/O.
Config File Structure
---------------------
Each kernel has a single JSON config file: {kernel_name}.json
The file uses a simplified 2-layer hierarchical structure:
{
"h100": { # GPU platform
"default": { ... }, # Fallback configuration
"batch_32_hidden_4096": { ... },
"batch_64_hidden_8192": { ... }
},
"a100": {
"default": { ... },
"batch_16_hidden_2048": { ... }
}
}
Example file: silu_mul_fp8.json
Each kernel has a directory: {kernel_name}/
Inside, each GPU platform has its own JSON file: {kernel_name}/{platform}.json
For example:
silu_mul_fp8/
nvidia_h100.json # { "default": {...}, "batch_32_hidden_4096": {...} }
nvidia_h200.json # { "batch_16_hidden_2048": {...} }
Each platform file maps config keys to Helion config objects.
Config keys should be structured strings that encode the relevant
parameters (e.g., "batch_32_hidden_4096", "seq_512_heads_16", "fp8_batch_64", etc.).
......@@ -212,8 +204,15 @@ class ConfigManager:
cls._instance = None
cls._instance_base_dir = None
def get_config_file_path(self, kernel_name: str) -> Path:
return self._base_dir / f"{kernel_name}.json"
def get_kernel_dir(self, kernel_name: str) -> Path:
return self._base_dir / kernel_name
def get_config_file_path(
self, kernel_name: str, platform: str | None = None
) -> Path:
if platform is not None:
return self.get_kernel_dir(kernel_name) / f"{platform}.json"
return self.get_kernel_dir(kernel_name)
def ensure_base_dir_exists(self) -> Path:
self._base_dir.mkdir(parents=True, exist_ok=True)
......@@ -230,39 +229,59 @@ class ConfigManager:
f"Config directory '{self._base_dir}' is not writable: {e}"
) from e
def load_config_set(self, kernel_name: str) -> ConfigSet:
config_path = self.get_config_file_path(kernel_name)
def _load_platform_file(self, kernel_name: str, platform: str) -> dict[str, Any]:
config_path = self.get_config_file_path(kernel_name, platform)
if not config_path.exists():
return ConfigSet.from_dict(kernel_name, {})
return {}
try:
with open(config_path) as f:
data = json.load(f)
return ConfigSet.from_dict(kernel_name, data)
return json.load(f)
except (json.JSONDecodeError, OSError) as e:
logger.error("Failed to load config file %s: %s", config_path, e)
return {}
def load_config_set(self, kernel_name: str) -> ConfigSet:
kernel_dir = self.get_kernel_dir(kernel_name)
if not kernel_dir.is_dir():
return ConfigSet.from_dict(kernel_name, {})
data: dict[str, Any] = {}
for platform_file in sorted(kernel_dir.glob("*.json")):
platform = platform_file.stem
try:
with open(platform_file) as f:
platform_data = json.load(f)
data[platform] = platform_data
except (json.JSONDecodeError, OSError) as e:
logger.error("Failed to load config file %s: %s", platform_file, e)
return ConfigSet.from_dict(kernel_name, data)
def get_platform_configs(
self, kernel_name: str, platform: str
) -> dict[str, helion.Config]:
config_set = self.load_config_set(kernel_name)
platform_data = self._load_platform_file(kernel_name, platform)
if not platform_data:
return {}
config_set = ConfigSet.from_dict(kernel_name, {platform: platform_data})
config_keys = config_set.get_config_keys(platform)
return {
config_key: config_set.get_config(platform, config_key)
for config_key in config_keys
}
def save_config_set(self, config_set: ConfigSet) -> Path:
config_path = self.get_config_file_path(config_set.kernel_name)
config_path.parent.mkdir(parents=True, exist_ok=True)
kernel_dir = self.get_kernel_dir(config_set.kernel_name)
kernel_dir.mkdir(parents=True, exist_ok=True)
with open(config_path, "w") as f:
json.dump(config_set.to_dict(), f, indent=2)
full_data = config_set.to_dict()
for platform, platform_data in full_data.items():
platform_path = kernel_dir / f"{platform}.json"
with open(platform_path, "w") as f:
json.dump(platform_data, f, indent=2)
logger.info("Saved config to: %s", platform_path)
logger.info("Saved config to: %s", config_path)
return config_path
return kernel_dir
def save_configs(
self,
......@@ -271,11 +290,18 @@ class ConfigManager:
configs: dict[str, "helion.Config"],
) -> Path:
"""Save configs for a kernel/platform, merging with existing."""
config_set = self.load_config_set(kernel_name)
platform_data = self._load_platform_file(kernel_name, platform)
for config_key, config in configs.items():
config_set.set_config(platform, config_key, config)
return self.save_config_set(config_set)
platform_data[config_key] = json.loads(config.to_json())
platform_path = self.get_config_file_path(kernel_name, platform)
platform_path.parent.mkdir(parents=True, exist_ok=True)
with open(platform_path, "w") as f:
json.dump(platform_data, f, indent=2)
logger.info("Saved config to: %s", platform_path)
return platform_path
def config_exists(self, kernel_name: str, platform: str, config_key: str) -> bool:
config_set = self.load_config_set(kernel_name)
return config_set.has_config(platform, config_key)
platform_data = self._load_platform_file(kernel_name, platform)
return config_key in platform_data
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment